Merge branch 'master' of github.com:yandex/ClickHouse into DOCAPI-4293

This commit is contained in:
BayoNet 2019-05-30 18:18:25 +03:00
commit b6dbe984da
261 changed files with 7891 additions and 1667 deletions

View File

@ -1,8 +1,59 @@
## ClickHouse release 19.6.2.11, 2019-05-13
### New Features
* TTL expressions for columns and tables. [#4212](https://github.com/yandex/ClickHouse/pull/4212) ([Anton Popov](https://github.com/CurtizJ))
* Added support for `brotli` compression for HTTP responses (Accept-Encoding: br) [#4388](https://github.com/yandex/ClickHouse/pull/4388) ([Mikhail](https://github.com/fandyushin))
* Added new function `isValidUTF8` for checking whether a set of bytes is correctly utf-8 encoded. [#4934](https://github.com/yandex/ClickHouse/pull/4934) ([Danila Kutenin](https://github.com/danlark1))
* Add new load balancing policy `first_or_random` which sends queries to the first specified host and if it's inaccessible send queries to random hosts of shard. Useful for cross-replication topology setups. [#5012](https://github.com/yandex/ClickHouse/pull/5012) ([nvartolomei](https://github.com/nvartolomei))
### Experimental Features
* Add setting `index_granularity_bytes` (adaptive index granularity) for MergeTree* tables family. [#4826](https://github.com/yandex/ClickHouse/pull/4826) ([alesapin](https://github.com/alesapin))
### Improvements
* Added support for non-constant and negative size and length arguments for function `substringUTF8`. [#4989](https://github.com/yandex/ClickHouse/pull/4989) ([alexey-milovidov](https://github.com/alexey-milovidov))
* Disable push-down to right table in left join, left table in right join, and both tables in full join. This fixes wrong JOIN results in some cases. [#4846](https://github.com/yandex/ClickHouse/pull/4846) ([Ivan](https://github.com/abyss7))
* `clickhouse-copier`: auto upload task configuration from `--task-file` option [#4876](https://github.com/yandex/ClickHouse/pull/4876) ([proller](https://github.com/proller))
* Added typos handler for storage factory and table functions factory. [#4891](https://github.com/yandex/ClickHouse/pull/4891) ([Danila Kutenin](https://github.com/danlark1))
* Support asterisks and qualified asterisks for multiple joins without subqueries [#4898](https://github.com/yandex/ClickHouse/pull/4898) ([Artem Zuikov](https://github.com/4ertus2))
* Make missing column error message more user friendly. [#4915](https://github.com/yandex/ClickHouse/pull/4915) ([Artem Zuikov](https://github.com/4ertus2))
### Performance Improvements
* Significant speedup of ASOF JOIN [#4924](https://github.com/yandex/ClickHouse/pull/4924) ([Martijn Bakker](https://github.com/Gladdy))
### Backward Incompatible Changes
* HTTP header `Query-Id` was renamed to `X-ClickHouse-Query-Id` for consistency. [#4972](https://github.com/yandex/ClickHouse/pull/4972) ([Mikhail](https://github.com/fandyushin))
### Bug Fixes
* Fixed potential null pointer dereference in `clickhouse-copier`. [#4900](https://github.com/yandex/ClickHouse/pull/4900) ([proller](https://github.com/proller))
* Fixed error on query with JOIN + ARRAY JOIN [#4938](https://github.com/yandex/ClickHouse/pull/4938) ([Artem Zuikov](https://github.com/4ertus2))
* Fixed hanging on start of the server when a dictionary depends on another dictionary via a database with engine=Dictionary. [#4962](https://github.com/yandex/ClickHouse/pull/4962) ([Vitaly Baranov](https://github.com/vitlibar))
* Partially fix distributed_product_mode = local. It's possible to allow columns of local tables in where/having/order by/... via table aliases. Throw exception if table does not have alias. There's not possible to access to the columns without table aliases yet. [#4986](https://github.com/yandex/ClickHouse/pull/4986) ([Artem Zuikov](https://github.com/4ertus2))
* Fix potentially wrong result for `SELECT DISTINCT` with `JOIN` [#5001](https://github.com/yandex/ClickHouse/pull/5001) ([Artem Zuikov](https://github.com/4ertus2))
### Build/Testing/Packaging Improvements
* Fixed test failures when running clickhouse-server on different host [#4713](https://github.com/yandex/ClickHouse/pull/4713) ([Vasily Nemkov](https://github.com/Enmk))
* clickhouse-test: Disable color control sequences in non tty environment. [#4937](https://github.com/yandex/ClickHouse/pull/4937) ([alesapin](https://github.com/alesapin))
* clickhouse-test: Allow use any test database (remove `test.` qualification where it possible) [#5008](https://github.com/yandex/ClickHouse/pull/5008) ([proller](https://github.com/proller))
* Fix ubsan errors [#5037](https://github.com/yandex/ClickHouse/pull/5037) ([Vitaly Baranov](https://github.com/vitlibar))
* Yandex LFAlloc was added to ClickHouse to allocate MarkCache and UncompressedCache data in different ways to catch segfaults more reliable [#4995](https://github.com/yandex/ClickHouse/pull/4995) ([Danila Kutenin](https://github.com/danlark1))
* Python util to help with backports and changelogs. [#4949](https://github.com/yandex/ClickHouse/pull/4949) ([Ivan](https://github.com/abyss7))
## ClickHouse release 19.5.4.22, 2019-05-13
### Bug fixes
* Fixed possible crash in bitmap* functions [#5220](https://github.com/yandex/ClickHouse/pull/5220) [#5228](https://github.com/yandex/ClickHouse/pull/5228) ([Andy Yang](https://github.com/andyyzh))
* Fixed very rare data race condition that could happen when executing a query with UNION ALL involving at least two SELECTs from system.columns, system.tables, system.parts, system.parts_tables or tables of Merge family and performing ALTER of columns of the related tables concurrently. [#5189](https://github.com/yandex/ClickHouse/pull/5189) ([alexey-milovidov](https://github.com/alexey-milovidov))
* Fixed error `Set for IN is not created yet in case of using single LowCardinality column in the left part of IN`. This error happened if LowCardinality column was the part of primary key. #5031 [#5154](https://github.com/yandex/ClickHouse/pull/5154) ([Nikolai Kochetov](https://github.com/KochetovNicolai))
* Modification of retention function: If a row satisfies both the first and NTH condition, only the first satisfied condition is added to the data state. Now all conditions that satisfy in a row of data are added to the data state. [#5119](https://github.com/yandex/ClickHouse/pull/5119) ([小路](https://github.com/nicelulu))
## ClickHouse release 19.5.3.8, 2019-04-18
### Bug fixes
* Fixed type of setting `max_partitions_per_insert_block` from boolean to UInt64. [#5028](https://github.com/yandex/ClickHouse/pull/5028) ([Mohammad Hossein Sekhavat](https://github.com/mhsekhavat))
## ClickHouse release 19.5.2.6, 2019-04-15
### New Features
@ -294,7 +345,7 @@
* Added support of `Nullable` types in `mysql` table function. [#4198](https://github.com/yandex/ClickHouse/pull/4198) ([Emmanuel Donin de Rosière](https://github.com/edonin))
* Support for arbitrary constant expressions in `LIMIT` clause. [#4246](https://github.com/yandex/ClickHouse/pull/4246) ([k3box](https://github.com/k3box))
* Added `topKWeighted` aggregate function that takes additional argument with (unsigned integer) weight. [#4245](https://github.com/yandex/ClickHouse/pull/4245) ([Andrew Golman](https://github.com/andrewgolman))
* `StorageJoin` now supports `join_overwrite` setting that allows overwriting existing values of the same key. [#3973](https://github.com/yandex/ClickHouse/pull/3973) ([Amos Bird](https://github.com/amosbird)
* `StorageJoin` now supports `join_any_take_last_row` setting that allows overwriting existing values of the same key. [#3973](https://github.com/yandex/ClickHouse/pull/3973) ([Amos Bird](https://github.com/amosbird)
* Added function `toStartOfInterval`. [#4304](https://github.com/yandex/ClickHouse/pull/4304) ([Vitaly Baranov](https://github.com/vitlibar))
* Added `RowBinaryWithNamesAndTypes` format. [#4200](https://github.com/yandex/ClickHouse/pull/4200) ([Oleg V. Kozlyuk](https://github.com/DarkWanderer))
* Added `IPv4` and `IPv6` data types. More effective implementations of `IPv*` functions. [#3669](https://github.com/yandex/ClickHouse/pull/3669) ([Vasily Nemkov](https://github.com/Enmk))

View File

@ -1,3 +1,53 @@
## ClickHouse release 19.6.2.11, 2019-05-13
### Новые возможности
* TTL выражения, позволяющие настроить время жизни и автоматическую очистку данных в таблице или в отдельных её столбцах. [#4212](https://github.com/yandex/ClickHouse/pull/4212) ([Anton Popov](https://github.com/CurtizJ))
* Добавлена поддержка алгоритма сжатия `brotli` в HTTP ответах (`Accept-Encoding: br`). Для тела POST запросов, эта возможность уже существовала. [#4388](https://github.com/yandex/ClickHouse/pull/4388) ([Mikhail](https://github.com/fandyushin))
* Добавлена функция `isValidUTF8` для проверки, содержит ли строка валидные данные в кодировке UTF-8. [#4934](https://github.com/yandex/ClickHouse/pull/4934) ([Danila Kutenin](https://github.com/danlark1))
* Добавлены новое правило балансировки (`load_balancing`) `first_or_random` по которому запросы посылаются на первый заданый хост и если он недоступен - на случайные хосты шарда. Полезно для топологий с кросс-репликацией. [#5012](https://github.com/yandex/ClickHouse/pull/5012) ([nvartolomei](https://github.com/nvartolomei))
### Эксперемннтальные возможности
* Добавлена настройка `index_granularity_bytes` (адаптивная гранулярность индекса) для таблиц семейства MergeTree* . [#4826](https://github.com/yandex/ClickHouse/pull/4826) ([alesapin](https://github.com/alesapin))
### Улучшения
* Добавлена поддержка для не константных и отрицательных значений аргументов смещения и длины для функции `substringUTF8`. [#4989](https://github.com/yandex/ClickHouse/pull/4989) ([alexey-milovidov](https://github.com/alexey-milovidov))
* Отключение push-down в правую таблицы в left join, левую таблицу в right join, и в обе таблицы в full join. Это исправляет неправильные JOIN результаты в некоторых случаях. [#4846](https://github.com/yandex/ClickHouse/pull/4846) ([Ivan](https://github.com/abyss7))
* `clickhouse-copier`: Автоматическая загрузка конфигурации задачи в zookeeper из `--task-file` опции [#4876](https://github.com/yandex/ClickHouse/pull/4876) ([proller](https://github.com/proller))
* Добавлены подсказки с учётом опечаток для имён движков таблиц и табличных функций. [#4891](https://github.com/yandex/ClickHouse/pull/4891) ([Danila Kutenin](https://github.com/danlark1))
* Поддержка выражений `select *` и `select tablename.*` для множественных join без подзапросов [#4898](https://github.com/yandex/ClickHouse/pull/4898) ([Artem Zuikov](https://github.com/4ertus2))
* Сообщения об ошибках об отсутствующих столбцах стали более понятными. [#4915](https://github.com/yandex/ClickHouse/pull/4915) ([Artem Zuikov](https://github.com/4ertus2))
### Улучшение производительности
* Существенное ускорение ASOF JOIN [#4924](https://github.com/yandex/ClickHouse/pull/4924) ([Martijn Bakker](https://github.com/Gladdy))
### Обратно несовместимые изменения
* HTTP заголовок `Query-Id` переименован в `X-ClickHouse-Query-Id` для соответствия. [#4972](https://github.com/yandex/ClickHouse/pull/4972) ([Mikhail](https://github.com/fandyushin))
### Исправления ошибок
* Исправлены возможные разыменования нулевого указателя в `clickhouse-copier`. [#4900](https://github.com/yandex/ClickHouse/pull/4900) ([proller](https://github.com/proller))
* Исправлены ошибки в запросах с JOIN + ARRAY JOIN [#4938](https://github.com/yandex/ClickHouse/pull/4938) ([Artem Zuikov](https://github.com/4ertus2))
* Исправлено зависание на старте сервера если внешний словарь зависит от другого словаря через использование таблицы из БД с движком `Dictionary`. [#4962](https://github.com/yandex/ClickHouse/pull/4962) ([Vitaly Baranov](https://github.com/vitlibar))
* При использовании `distributed_product_mode = 'local'` корректно работает использование столбцов локальных таблиц в where/having/order by/... через табличные алиасы. Выкидывает исключение если таблица не имеет алиас. Доступ к столбцам без алиасов пока не возможен. [#4986](https://github.com/yandex/ClickHouse/pull/4986) ([Artem Zuikov](https://github.com/4ertus2))
* Исправлен потенциально некорректный результат для `SELECT DISTINCT` с `JOIN` [#5001](https://github.com/yandex/ClickHouse/pull/5001) ([Artem Zuikov](https://github.com/4ertus2))
### Улучшения сборки/тестирования/пакетирования
* Исправлена неработоспособность тестов, если `clickhouse-server` запущен на удалённом хосте [#4713](https://github.com/yandex/ClickHouse/pull/4713) ([Vasily Nemkov](https://github.com/Enmk))
* `clickhouse-test`: Отключена раскраска результата, если команда запускается не в терминале. [#4937](https://github.com/yandex/ClickHouse/pull/4937) ([alesapin](https://github.com/alesapin))
* `clickhouse-test`: Возможность использования не только базы данных test [#5008](https://github.com/yandex/ClickHouse/pull/5008) ([proller](https://github.com/proller))
* Исправлены ошибки при запуске тестов под UBSan [#5037](https://github.com/yandex/ClickHouse/pull/5037) ([Vitaly Baranov](https://github.com/vitlibar))
* Добавлен аллокатор Yandex LFAlloc для аллоцирования MarkCache и UncompressedCache данных разными способами для более надежного отлавливания проездов по памяти [#4995](https://github.com/yandex/ClickHouse/pull/4995) ([Danila Kutenin](https://github.com/danlark1))
* Утилита для упрощения бэкпортирования изменений в старые релизы и составления changelogs. [#4949](https://github.com/yandex/ClickHouse/pull/4949) ([Ivan](https://github.com/abyss7))
## ClickHouse release 19.5.4.22, 2019-05-13
### Исправления ошибок
* Исправлены возможные падения в bitmap* функциях [#5220](https://github.com/yandex/ClickHouse/pull/5220) [#5228](https://github.com/yandex/ClickHouse/pull/5228) ([Andy Yang](https://github.com/andyyzh))
* Исправлен очень редкий data race condition который мог произойти при выполнении запроса с UNION ALL включающего минимум два SELECT из таблиц system.columns, system.tables, system.parts, system.parts_tables или таблиц семейства Merge и одновременно выполняющихся запросов ALTER столбцов соответствующих таблиц. [#5189](https://github.com/yandex/ClickHouse/pull/5189) ([alexey-milovidov](https://github.com/alexey-milovidov))
* Исправлена ошибка `Set for IN is not created yet in case of using single LowCardinality column in the left part of IN`. Эта ошибка возникала когда LowCardinality столбец была частью primary key. #5031 [#5154](https://github.com/yandex/ClickHouse/pull/5154) ([Nikolai Kochetov](https://github.com/KochetovNicolai))
* Исправление функции retention: только первое соответствующее условие добавлялось в состояние данных. Сейчас все условия которые удовлетворяют в строке данных добавляются в состояние. [#5119](https://github.com/yandex/ClickHouse/pull/5119) ([小路](https://github.com/nicelulu))
## ClickHouse release 19.5.3.8, 2019-04-18
### Исправления ошибок
@ -286,7 +336,7 @@
* Добавлена поддержка `Nullable` типов в табличной функции `mysql`. [#4198](https://github.com/yandex/ClickHouse/pull/4198) ([Emmanuel Donin de Rosière](https://github.com/edonin))
* Добавлена поддержка произвольных константных выражений в секции `LIMIT`. [#4246](https://github.com/yandex/ClickHouse/pull/4246) ([k3box](https://github.com/k3box))
* Добавлена агрегатная функция `topKWeighted` - вариант `topK`, позволяющий задавать (целый неотрицательный) вес добавляемого значения. [#4245](https://github.com/yandex/ClickHouse/pull/4245) ([Andrew Golman](https://github.com/andrewgolman))
* Движок `Join` теперь поддерживает настройку `join_overwrite`, которая позволяет перезаписывать значения для существующих ключей. [#3973](https://github.com/yandex/ClickHouse/pull/3973) ([Amos Bird](https://github.com/amosbird))
* Движок `Join` теперь поддерживает настройку `join_any_take_last_row`, которая позволяет перезаписывать значения для существующих ключей. [#3973](https://github.com/yandex/ClickHouse/pull/3973) ([Amos Bird](https://github.com/amosbird))
* Добавлена функция `toStartOfInterval`. [#4304](https://github.com/yandex/ClickHouse/pull/4304) ([Vitaly Baranov](https://github.com/vitlibar))
* Добавлена функция `toStartOfTenMinutes`. [#4298](https://github.com/yandex/ClickHouse/pull/4298) ([Vitaly Baranov](https://github.com/vitlibar))
* Добавлен формат `RowBinaryWithNamesAndTypes`. [#4200](https://github.com/yandex/ClickHouse/pull/4200) ([Oleg V. Kozlyuk](https://github.com/DarkWanderer))

View File

@ -49,19 +49,20 @@ if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
endif ()
if (NOT CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7)
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
if (WEVERYTHING)
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-return-std-move-in-c++11")
endif ()
endif ()
if (NOT CMAKE_CXX_COMPILER_VERSION VERSION_LESS 8)
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wextra-semi-stmt -Wshadow-field -Wstring-plus-int -Wempty-init-stmt")
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wshadow-field -Wstring-plus-int")
if(NOT APPLE)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wextra-semi-stmt -Wempty-init-stmt")
endif()
endif ()
if (NOT CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9)
if (WEVERYTHING)
if (WEVERYTHING AND NOT APPLE)
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-ctad-maybe-unsupported")
endif ()
endif ()

View File

@ -1,11 +1,11 @@
# This strings autochanged from release_lib.sh:
set(VERSION_REVISION 54420)
set(VERSION_REVISION 54421)
set(VERSION_MAJOR 19)
set(VERSION_MINOR 8)
set(VERSION_MINOR 9)
set(VERSION_PATCH 1)
set(VERSION_GITHASH a76e504f45ff4a74e8c492bd269f022352d5f6d9)
set(VERSION_DESCRIBE v19.8.1.1-testing)
set(VERSION_STRING 19.8.1.1)
set(VERSION_GITHASH 0c2aa460651a462f14efc7e995840a244531d373)
set(VERSION_DESCRIBE v19.9.1.1-testing)
set(VERSION_STRING 19.9.1.1)
# end of autochange
set(VERSION_EXTRA "" CACHE STRING "")

View File

@ -325,8 +325,8 @@ private:
double seconds = watch.elapsedSeconds();
std::lock_guard lock(mutex);
info_per_interval.add(seconds, progress.rows, progress.bytes, info.rows, info.bytes);
info_total.add(seconds, progress.rows, progress.bytes, info.rows, info.bytes);
info_per_interval.add(seconds, progress.read_rows, progress.read_bytes, info.rows, info.bytes);
info_total.add(seconds, progress.read_rows, progress.read_bytes, info.rows, info.bytes);
}

View File

@ -435,7 +435,7 @@ private:
#if USE_READLINE
int res = read_history(history_file.c_str());
if (res)
throwFromErrno("Cannot read history from file " + history_file, ErrorCodes::CANNOT_READ_HISTORY);
std::cerr << "Cannot read history from file " + history_file + ": "+ errnoToString(ErrorCodes::CANNOT_READ_HISTORY);
#endif
}
else /// Create history file.
@ -612,7 +612,7 @@ private:
#if USE_READLINE && HAVE_READLINE_HISTORY
if (!history_file.empty() && append_history(1, history_file.c_str()))
throwFromErrno("Cannot append history to file " + history_file, ErrorCodes::CANNOT_APPEND_HISTORY);
std::cerr << "Cannot append history to file " + history_file + ": " + errnoToString(ErrorCodes::CANNOT_APPEND_HISTORY);
#endif
prev_input = input;
@ -866,7 +866,7 @@ private:
std::cout << std::endl
<< processed_rows << " rows in set. Elapsed: " << watch.elapsedSeconds() << " sec. ";
if (progress.rows >= 1000)
if (progress.read_rows >= 1000)
writeFinalProgress();
std::cout << std::endl << std::endl;
@ -1420,23 +1420,23 @@ private:
<< " Progress: ";
message
<< formatReadableQuantity(progress.rows) << " rows, "
<< formatReadableSizeWithDecimalSuffix(progress.bytes);
<< formatReadableQuantity(progress.read_rows) << " rows, "
<< formatReadableSizeWithDecimalSuffix(progress.read_bytes);
size_t elapsed_ns = watch.elapsed();
if (elapsed_ns)
message << " ("
<< formatReadableQuantity(progress.rows * 1000000000.0 / elapsed_ns) << " rows/s., "
<< formatReadableSizeWithDecimalSuffix(progress.bytes * 1000000000.0 / elapsed_ns) << "/s.) ";
<< formatReadableQuantity(progress.read_rows * 1000000000.0 / elapsed_ns) << " rows/s., "
<< formatReadableSizeWithDecimalSuffix(progress.read_bytes * 1000000000.0 / elapsed_ns) << "/s.) ";
else
message << ". ";
written_progress_chars = message.count() - prefix_size - (increment % 8 == 7 ? 10 : 13); /// Don't count invisible output (escape sequences).
/// If the approximate number of rows to process is known, we can display a progress bar and percentage.
if (progress.total_rows > 0)
if (progress.total_rows_to_read > 0)
{
size_t total_rows_corrected = std::max(progress.rows, progress.total_rows);
size_t total_rows_corrected = std::max(progress.read_rows, progress.total_rows_to_read);
/// To avoid flicker, display progress bar only if .5 seconds have passed since query execution start
/// and the query is less than halfway done.
@ -1444,7 +1444,7 @@ private:
if (elapsed_ns > 500000000)
{
/// Trigger to start displaying progress bar. If query is mostly done, don't display it.
if (progress.rows * 2 < total_rows_corrected)
if (progress.read_rows * 2 < total_rows_corrected)
show_progress_bar = true;
if (show_progress_bar)
@ -1452,7 +1452,7 @@ private:
ssize_t width_of_progress_bar = static_cast<ssize_t>(terminal_size.ws_col) - written_progress_chars - strlen(" 99%");
if (width_of_progress_bar > 0)
{
std::string bar = UnicodeBar::render(UnicodeBar::getWidth(progress.rows, 0, total_rows_corrected, width_of_progress_bar));
std::string bar = UnicodeBar::render(UnicodeBar::getWidth(progress.read_rows, 0, total_rows_corrected, width_of_progress_bar));
message << "\033[0;32m" << bar << "\033[0m";
if (width_of_progress_bar > static_cast<ssize_t>(bar.size() / UNICODE_BAR_CHAR_SIZE))
message << std::string(width_of_progress_bar - bar.size() / UNICODE_BAR_CHAR_SIZE, ' ');
@ -1461,7 +1461,7 @@ private:
}
/// Underestimate percentage a bit to avoid displaying 100%.
message << ' ' << (99 * progress.rows / total_rows_corrected) << '%';
message << ' ' << (99 * progress.read_rows / total_rows_corrected) << '%';
}
message << ENABLE_LINE_WRAPPING;
@ -1474,14 +1474,14 @@ private:
void writeFinalProgress()
{
std::cout << "Processed "
<< formatReadableQuantity(progress.rows) << " rows, "
<< formatReadableSizeWithDecimalSuffix(progress.bytes);
<< formatReadableQuantity(progress.read_rows) << " rows, "
<< formatReadableSizeWithDecimalSuffix(progress.read_bytes);
size_t elapsed_ns = watch.elapsed();
if (elapsed_ns)
std::cout << " ("
<< formatReadableQuantity(progress.rows * 1000000000.0 / elapsed_ns) << " rows/s., "
<< formatReadableSizeWithDecimalSuffix(progress.bytes * 1000000000.0 / elapsed_ns) << "/s.) ";
<< formatReadableQuantity(progress.read_rows * 1000000000.0 / elapsed_ns) << " rows/s., "
<< formatReadableSizeWithDecimalSuffix(progress.read_bytes * 1000000000.0 / elapsed_ns) << "/s.) ";
else
std::cout << ". ";
}

View File

@ -1,6 +1,7 @@
#include "PerformanceTest.h"
#include <Core/Types.h>
#include <Common/CpuId.h>
#include <common/getMemoryAmount.h>
#include <IO/ReadBufferFromFile.h>
#include <IO/ReadHelpers.h>
@ -71,6 +72,7 @@ bool PerformanceTest::checkPreconditions() const
Strings preconditions;
config->keys("preconditions", preconditions);
size_t table_precondition_index = 0;
size_t cpu_precondition_index = 0;
for (const std::string & precondition : preconditions)
{
@ -136,6 +138,30 @@ bool PerformanceTest::checkPreconditions() const
return false;
}
}
if (precondition == "cpu")
{
std::string precondition_key = "preconditions.cpu[" + std::to_string(cpu_precondition_index++) + "]";
std::string flag_to_check = config->getString(precondition_key);
#define CHECK_CPU_PRECONDITION(OP) \
if (flag_to_check == #OP) \
{ \
if (!Cpu::CpuFlagsCache::have_##OP) \
{ \
LOG_WARNING(log, "CPU doesn't support " << #OP); \
return false; \
} \
} else
CPU_ID_ENUMERATE(CHECK_CPU_PRECONDITION)
{
LOG_WARNING(log, "CPU doesn't support " << flag_to_check);
return false;
}
#undef CHECK_CPU_PRECONDITION
}
}
return true;

View File

@ -13,7 +13,7 @@ void checkFulfilledConditionsAndUpdate(
TestStats & statistics, TestStopConditions & stop_conditions,
InterruptListener & interrupt_listener)
{
statistics.add(progress.rows, progress.bytes);
statistics.add(progress.read_rows, progress.read_bytes);
stop_conditions.reportRowsRead(statistics.total_rows_read);
stop_conditions.reportBytesReadUncompressed(statistics.total_bytes_read);

View File

@ -8,6 +8,8 @@ set(CLICKHOUSE_SERVER_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/RootRequestHandler.cpp
${CMAKE_CURRENT_SOURCE_DIR}/Server.cpp
${CMAKE_CURRENT_SOURCE_DIR}/TCPHandler.cpp
${CMAKE_CURRENT_SOURCE_DIR}/MySQLHandler.cpp
${CMAKE_CURRENT_SOURCE_DIR}/MySQLHandlerFactory.cpp
)
set(CLICKHOUSE_SERVER_LINK PRIVATE clickhouse_dictionaries clickhouse_common_io PUBLIC daemon PRIVATE clickhouse_storages_system clickhouse_functions clickhouse_aggregate_functions clickhouse_table_functions ${Poco_Net_LIBRARY})

View File

@ -453,30 +453,6 @@ void HTTPHandler::processQuery(
return false;
};
/// Used in case of POST request with form-data, but it isn't expected to be deleted after that scope.
std::string full_query;
/// Support for "external data for query processing".
if (startsWith(request.getContentType().data(), "multipart/form-data"))
{
ExternalTablesHandler handler(context, params);
params.load(request, istr, handler);
/// Skip unneeded parameters to avoid confusing them later with context settings or query parameters.
reserved_param_suffixes.emplace_back("_format");
reserved_param_suffixes.emplace_back("_types");
reserved_param_suffixes.emplace_back("_structure");
/// Params are of both form params POST and uri (GET params)
for (const auto & it : params)
if (it.first == "query")
full_query += it.second;
in = std::make_unique<ReadBufferFromString>(full_query);
}
else
in = std::make_unique<ConcatReadBuffer>(*in_param, *in_post_maybe_compressed);
/// Settings can be overridden in the query.
/// Some parameters (database, default_format, everything used in the code above) do not
/// belong to the Settings class.
@ -497,30 +473,63 @@ void HTTPHandler::processQuery(
settings.readonly = 2;
}
SettingsChanges settings_changes;
for (auto it = params.begin(); it != params.end(); ++it)
bool isExternalData = startsWith(request.getContentType().data(), "multipart/form-data");
if (isExternalData)
{
if (it->first == "database")
/// Skip unneeded parameters to avoid confusing them later with context settings or query parameters.
reserved_param_suffixes.reserve(3);
/// It is a bug and ambiguity with `date_time_input_format` and `low_cardinality_allow_in_native_format` formats/settings.
reserved_param_suffixes.emplace_back("_format");
reserved_param_suffixes.emplace_back("_types");
reserved_param_suffixes.emplace_back("_structure");
}
SettingsChanges settings_changes;
for (const auto & [key, value] : params)
{
if (key == "database")
{
context.setCurrentDatabase(it->second);
context.setCurrentDatabase(value);
}
else if (it->first == "default_format")
else if (key == "default_format")
{
context.setDefaultFormat(it->second);
context.setDefaultFormat(value);
}
else if (param_could_be_skipped(it->first))
else if (param_could_be_skipped(key))
{
}
else
{
/// All other query parameters are treated as settings.
settings_changes.push_back({it->first, it->second});
settings_changes.push_back({key, value});
}
}
/// For external data we also want settings
context.checkSettingsConstraints(settings_changes);
context.applySettingsChanges(settings_changes);
/// Used in case of POST request with form-data, but it isn't expected to be deleted after that scope.
std::string full_query;
/// Support for "external data for query processing".
if (isExternalData)
{
ExternalTablesHandler handler(context, params);
params.load(request, istr, handler);
/// Params are of both form params POST and uri (GET params)
for (const auto & it : params)
if (it.first == "query")
full_query += it.second;
in = std::make_unique<ReadBufferFromString>(full_query);
}
else
in = std::make_unique<ConcatReadBuffer>(*in_param, *in_post_maybe_compressed);
/// HTTP response compression is turned on only if the client signalled that they support it
/// (using Accept-Encoding header) and 'enable_http_compression' setting is turned on.
used_output.out->setCompression(client_supports_http_compression && settings.enable_http_compression);

View File

@ -1,17 +1,16 @@
#include "InterserverIOHTTPHandler.h"
#include <Poco/Net/HTTPBasicCredentials.h>
#include <Poco/Net/HTTPServerRequest.h>
#include <Poco/Net/HTTPServerResponse.h>
#include <common/logger_useful.h>
#include <Common/HTMLForm.h>
#include <Common/setThreadName.h>
#include <Compression/CompressedWriteBuffer.h>
#include <IO/ReadBufferFromIStream.h>
#include <IO/WriteBufferFromHTTPServerResponse.h>
#include <Interpreters/InterserverIOHandler.h>
#include "InterserverIOHTTPHandler.h"
#include "IServer.h"
namespace DB
{
@ -50,7 +49,7 @@ std::pair<String, bool> InterserverIOHTTPHandler::checkAuthentication(Poco::Net:
return {"", true};
}
void InterserverIOHTTPHandler::processQuery(Poco::Net::HTTPServerRequest & request, Poco::Net::HTTPServerResponse & response)
void InterserverIOHTTPHandler::processQuery(Poco::Net::HTTPServerRequest & request, Poco::Net::HTTPServerResponse & response, Output & used_output)
{
HTMLForm params(request);
@ -61,24 +60,17 @@ void InterserverIOHTTPHandler::processQuery(Poco::Net::HTTPServerRequest & reque
ReadBufferFromIStream body(request.stream());
const auto & config = server.config();
unsigned keep_alive_timeout = config.getUInt("keep_alive_timeout", 10);
WriteBufferFromHTTPServerResponse out(request, response, keep_alive_timeout);
auto endpoint = server.context().getInterserverIOHandler().getEndpoint(endpoint_name);
if (compress)
{
CompressedWriteBuffer compressed_out(out);
CompressedWriteBuffer compressed_out(*used_output.out.get());
endpoint->processQuery(params, body, compressed_out, response);
}
else
{
endpoint->processQuery(params, body, out, response);
endpoint->processQuery(params, body, *used_output.out.get(), response);
}
out.finalize();
}
@ -90,30 +82,30 @@ void InterserverIOHTTPHandler::handleRequest(Poco::Net::HTTPServerRequest & requ
if (request.getVersion() == Poco::Net::HTTPServerRequest::HTTP_1_1)
response.setChunkedTransferEncoding(true);
Output used_output;
const auto & config = server.config();
unsigned keep_alive_timeout = config.getUInt("keep_alive_timeout", 10);
used_output.out = std::make_shared<WriteBufferFromHTTPServerResponse>(request, response, keep_alive_timeout);
try
{
if (auto [msg, success] = checkAuthentication(request); success)
if (auto [message, success] = checkAuthentication(request); success)
{
processQuery(request, response);
processQuery(request, response, used_output);
LOG_INFO(log, "Done processing query");
}
else
{
response.setStatusAndReason(Poco::Net::HTTPServerResponse::HTTP_UNAUTHORIZED);
if (!response.sent())
response.send() << msg << std::endl;
writeString(message, *used_output.out);
LOG_WARNING(log, "Query processing failed request: '" << request.getURI() << "' authentification failed");
}
}
catch (Exception & e)
{
if (e.code() == ErrorCodes::TOO_MANY_SIMULTANEOUS_QUERIES)
{
if (!response.sent())
response.send();
return;
}
response.setStatusAndReason(Poco::Net::HTTPResponse::HTTP_INTERNAL_SERVER_ERROR);
@ -122,7 +114,7 @@ void InterserverIOHTTPHandler::handleRequest(Poco::Net::HTTPServerRequest & requ
std::string message = getCurrentExceptionMessage(is_real_error);
if (!response.sent())
response.send() << message << std::endl;
writeString(message, *used_output.out);
if (is_real_error)
LOG_ERROR(log, message);
@ -134,7 +126,8 @@ void InterserverIOHTTPHandler::handleRequest(Poco::Net::HTTPServerRequest & requ
response.setStatusAndReason(Poco::Net::HTTPResponse::HTTP_INTERNAL_SERVER_ERROR);
std::string message = getCurrentExceptionMessage(false);
if (!response.sent())
response.send() << message << std::endl;
writeString(message, *used_output.out);
LOG_ERROR(log, message);
}
}

View File

@ -1,12 +1,10 @@
#pragma once
#include <memory>
#include <Poco/Logger.h>
#include <Poco/Net/HTTPRequestHandler.h>
#include <Common/CurrentMetrics.h>
#include "IServer.h"
namespace CurrentMetrics
{
@ -16,6 +14,9 @@ namespace CurrentMetrics
namespace DB
{
class IServer;
class WriteBufferFromHTTPServerResponse;
class InterserverIOHTTPHandler : public Poco::Net::HTTPRequestHandler
{
public:
@ -28,12 +29,17 @@ public:
void handleRequest(Poco::Net::HTTPServerRequest & request, Poco::Net::HTTPServerResponse & response) override;
private:
struct Output
{
std::shared_ptr<WriteBufferFromHTTPServerResponse> out;
};
IServer & server;
Poco::Logger * log;
CurrentMetrics::Increment metric_increment{CurrentMetrics::InterserverConnection};
void processQuery(Poco::Net::HTTPServerRequest & request, Poco::Net::HTTPServerResponse & response);
void processQuery(Poco::Net::HTTPServerRequest & request, Poco::Net::HTTPServerResponse & response, Output & used_output);
std::pair<String, bool> checkAuthentication(Poco::Net::HTTPServerRequest & request) const;
};

View File

@ -0,0 +1,370 @@
#include <DataStreams/copyData.h>
#include <IO/ReadBufferFromString.h>
#include <IO/ReadBufferFromPocoSocket.h>
#include <IO/WriteBufferFromPocoSocket.h>
#include <Interpreters/executeQuery.h>
#include <Storages/IStorage.h>
#include <Core/MySQLProtocol.h>
#include <Core/NamesAndTypes.h>
#include <Columns/ColumnVector.h>
#include <Common/config_version.h>
#include <Common/NetException.h>
#include <Common/OpenSSLHelpers.h>
#include <Poco/Crypto/RSAKey.h>
#include <Poco/Crypto/CipherFactory.h>
#include <Poco/Net/SecureStreamSocket.h>
#include <Poco/Net/SSLManager.h>
#include "MySQLHandler.h"
#include <limits>
#include <ext/scope_guard.h>
namespace DB
{
using namespace MySQLProtocol;
using Poco::Net::SecureStreamSocket;
using Poco::Net::SSLManager;
namespace ErrorCodes
{
extern const int MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES;
extern const int OPENSSL_ERROR;
}
MySQLHandler::MySQLHandler(IServer & server_, const Poco::Net::StreamSocket & socket_, RSA & public_key, RSA & private_key, bool ssl_enabled, size_t connection_id)
: Poco::Net::TCPServerConnection(socket_)
, server(server_)
, log(&Poco::Logger::get("MySQLHandler"))
, connection_context(server.context())
, connection_id(connection_id)
, public_key(public_key)
, private_key(private_key)
{
server_capability_flags = CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH | CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA | CLIENT_CONNECT_WITH_DB | CLIENT_DEPRECATE_EOF;
if (ssl_enabled)
server_capability_flags |= CLIENT_SSL;
}
void MySQLHandler::run()
{
connection_context = server.context();
connection_context.setDefaultFormat("MySQLWire");
in = std::make_shared<ReadBufferFromPocoSocket>(socket());
out = std::make_shared<WriteBufferFromPocoSocket>(socket());
packet_sender = std::make_shared<PacketSender>(*in, *out, connection_context.sequence_id);
try
{
String scramble = generateScramble();
/** Native authentication sent 20 bytes + '\0' character = 21 bytes.
* This plugin must do the same to stay consistent with historical behavior if it is set to operate as a default plugin.
* https://github.com/mysql/mysql-server/blob/8.0/sql/auth/sql_authentication.cc#L3994
*/
Handshake handshake(server_capability_flags, connection_id, VERSION_STRING + String("-") + VERSION_NAME, scramble + '\0');
packet_sender->sendPacket<Handshake>(handshake, true);
LOG_TRACE(log, "Sent handshake");
HandshakeResponse handshake_response = finishHandshake();
connection_context.client_capabilities = handshake_response.capability_flags;
if (handshake_response.max_packet_size)
connection_context.max_packet_size = handshake_response.max_packet_size;
if (!connection_context.max_packet_size)
connection_context.max_packet_size = MAX_PACKET_LENGTH;
LOG_DEBUG(log, "Capabilities: " << handshake_response.capability_flags
<< "\nmax_packet_size: "
<< handshake_response.max_packet_size
<< "\ncharacter_set: "
<< handshake_response.character_set
<< "\nuser: "
<< handshake_response.username
<< "\nauth_response length: "
<< handshake_response.auth_response.length()
<< "\nauth_response: "
<< handshake_response.auth_response
<< "\ndatabase: "
<< handshake_response.database
<< "\nauth_plugin_name: "
<< handshake_response.auth_plugin_name);
client_capability_flags = handshake_response.capability_flags;
if (!(client_capability_flags & CLIENT_PROTOCOL_41))
throw Exception("Required capability: CLIENT_PROTOCOL_41.", ErrorCodes::MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES);
if (!(client_capability_flags & CLIENT_PLUGIN_AUTH))
throw Exception("Required capability: CLIENT_PLUGIN_AUTH.", ErrorCodes::MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES);
authenticate(handshake_response, scramble);
OK_Packet ok_packet(0, handshake_response.capability_flags, 0, 0, 0);
packet_sender->sendPacket(ok_packet, true);
while (true)
{
packet_sender->resetSequenceId();
String payload = packet_sender->receivePacketPayload();
int command = payload[0];
LOG_DEBUG(log, "Received command: " << std::to_string(command) << ". Connection id: " << connection_id << ".");
try
{
switch (command)
{
case COM_QUIT:
return;
case COM_INIT_DB:
comInitDB(payload);
break;
case COM_QUERY:
comQuery(payload);
break;
case COM_FIELD_LIST:
comFieldList(payload);
break;
case COM_PING:
comPing();
break;
default:
throw Exception(Poco::format("Command %d is not implemented.", command), ErrorCodes::NOT_IMPLEMENTED);
}
}
catch (const NetException & exc)
{
log->log(exc);
throw;
}
catch (const Exception & exc)
{
log->log(exc);
packet_sender->sendPacket(ERR_Packet(exc.code(), "00000", exc.message()), true);
}
}
}
catch (Poco::Exception & exc)
{
log->log(exc);
}
}
/** Reads 3 bytes, finds out whether it is SSLRequest or HandshakeResponse packet, starts secure connection, if it is SSLRequest.
* Reading is performed from socket instead of ReadBuffer to prevent reading part of SSL handshake.
* If we read it from socket, it will be impossible to start SSL connection using Poco. Size of SSLRequest packet payload is 32 bytes, thus we can read at most 36 bytes.
*/
MySQLProtocol::HandshakeResponse MySQLHandler::finishHandshake()
{
HandshakeResponse packet;
size_t packet_size = PACKET_HEADER_SIZE + SSL_REQUEST_PAYLOAD_SIZE;
/// Buffer for SSLRequest or part of HandshakeResponse.
char buf[packet_size];
size_t pos = 0;
/// Reads at least count and at most packet_size bytes.
auto read_bytes = [this, &buf, &pos, &packet_size](size_t count) -> void {
while (pos < count)
{
int ret = socket().receiveBytes(buf + pos, packet_size - pos);
if (ret == 0)
{
throw Exception("Cannot read all data. Bytes read: " + std::to_string(pos) + ". Bytes expected: 3.", ErrorCodes::CANNOT_READ_ALL_DATA);
}
pos += ret;
}
};
read_bytes(3); /// We can find out whether it is SSLRequest of HandshakeResponse by first 3 bytes.
size_t payload_size = unalignedLoad<uint32_t>(buf) & 0xFFFFFFu;
LOG_TRACE(log, "payload size: " << payload_size);
if (payload_size == SSL_REQUEST_PAYLOAD_SIZE)
{
read_bytes(packet_size); /// Reading rest SSLRequest.
SSLRequest ssl_request;
ssl_request.readPayload(String(buf + PACKET_HEADER_SIZE, pos - PACKET_HEADER_SIZE));
connection_context.client_capabilities = ssl_request.capability_flags;
connection_context.max_packet_size = ssl_request.max_packet_size ? ssl_request.max_packet_size : MAX_PACKET_LENGTH;
secure_connection = true;
ss = std::make_shared<SecureStreamSocket>(SecureStreamSocket::attach(socket(), SSLManager::instance().defaultServerContext()));
in = std::make_shared<ReadBufferFromPocoSocket>(*ss);
out = std::make_shared<WriteBufferFromPocoSocket>(*ss);
connection_context.sequence_id = 2;
packet_sender = std::make_shared<PacketSender>(*in, *out, connection_context.sequence_id);
packet_sender->max_packet_size = connection_context.max_packet_size;
packet_sender->receivePacket(packet); /// Reading HandshakeResponse from secure socket.
}
else
{
/// Reading rest of HandshakeResponse.
packet_size = PACKET_HEADER_SIZE + payload_size;
WriteBufferFromOwnString buf_for_handshake_response;
buf_for_handshake_response.write(buf, pos);
copyData(*packet_sender->in, buf_for_handshake_response, packet_size - pos);
packet.readPayload(buf_for_handshake_response.str().substr(PACKET_HEADER_SIZE));
packet_sender->sequence_id++;
}
return packet;
}
String MySQLHandler::generateScramble()
{
String scramble(MySQLProtocol::SCRAMBLE_LENGTH, 0);
Poco::RandomInputStream generator;
for (size_t i = 0; i < scramble.size(); i++)
{
generator >> scramble[i];
}
return scramble;
}
void MySQLHandler::authenticate(const HandshakeResponse & handshake_response, const String & scramble)
{
String auth_response;
AuthSwitchResponse response;
if (handshake_response.auth_plugin_name != Authentication::SHA256)
{
packet_sender->sendPacket(AuthSwitchRequest(Authentication::SHA256, scramble + '\0'), true);
if (in->eof())
throw Exception(
"Client doesn't support authentication method " + String(Authentication::SHA256) + " used by ClickHouse",
ErrorCodes::MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES);
packet_sender->receivePacket(response);
auth_response = response.value;
LOG_TRACE(log, "Authentication method mismatch.");
}
else
{
auth_response = handshake_response.auth_response;
LOG_TRACE(log, "Authentication method match.");
}
if (auth_response == "\1")
{
LOG_TRACE(log, "Client requests public key.");
BIO * mem = BIO_new(BIO_s_mem());
SCOPE_EXIT(BIO_free(mem));
if (PEM_write_bio_RSA_PUBKEY(mem, &public_key) != 1)
{
throw Exception("Failed to write public key to memory. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
}
char * pem_buf = nullptr;
long pem_size = BIO_get_mem_data(mem, &pem_buf);
String pem(pem_buf, pem_size);
LOG_TRACE(log, "Key: " << pem);
AuthMoreData data(pem);
packet_sender->sendPacket(data, true);
packet_sender->receivePacket(response);
auth_response = response.value;
}
else
{
LOG_TRACE(log, "Client didn't request public key.");
}
String password;
/** Decrypt password, if it's not empty.
* The original intention was that the password is a string[NUL] but this never got enforced properly so now we have to accept that
* an empty packet is a blank password, thus the check for auth_response.empty() has to be made too.
* https://github.com/mysql/mysql-server/blob/8.0/sql/auth/sql_authentication.cc#L4017
*/
if (!secure_connection && !auth_response.empty() && auth_response != String("\0", 1))
{
LOG_TRACE(log, "Received nonempty password");
auto ciphertext = reinterpret_cast<unsigned char *>(auth_response.data());
unsigned char plaintext[RSA_size(&private_key)];
int plaintext_size = RSA_private_decrypt(auth_response.size(), ciphertext, plaintext, &private_key, RSA_PKCS1_OAEP_PADDING);
if (plaintext_size == -1)
{
throw Exception("Failed to decrypt auth data. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
}
password.resize(plaintext_size);
for (int i = 0; i < plaintext_size; i++)
{
password[i] = plaintext[i] ^ static_cast<unsigned char>(scramble[i % scramble.size()]);
}
}
else if (secure_connection)
{
password = auth_response;
}
else
{
LOG_TRACE(log, "Received empty password");
}
if (!password.empty())
{
password.pop_back(); /// terminating null byte
}
try
{
connection_context.setUser(handshake_response.username, password, socket().address(), "");
connection_context.setCurrentDatabase(handshake_response.database);
connection_context.setCurrentQueryId("");
LOG_ERROR(log, "Authentication for user " << handshake_response.username << " succeeded.");
}
catch (const Exception & exc)
{
LOG_ERROR(log, "Authentication for user " << handshake_response.username << " failed.");
packet_sender->sendPacket(ERR_Packet(exc.code(), "00000", exc.message()), true);
throw;
}
}
void MySQLHandler::comInitDB(const String & payload)
{
String database = payload.substr(1);
LOG_DEBUG(log, "Setting current database to " << database);
connection_context.setCurrentDatabase(database);
packet_sender->sendPacket(OK_Packet(0, client_capability_flags, 0, 0, 1), true);
}
void MySQLHandler::comFieldList(const String & payload)
{
ComFieldList packet;
packet.readPayload(payload);
String database = connection_context.getCurrentDatabase();
StoragePtr tablePtr = connection_context.getTable(database, packet.table);
for (const NameAndTypePair & column: tablePtr->getColumns().getAll())
{
ColumnDefinition column_definition(
database, packet.table, packet.table, column.name, column.name, CharacterSet::binary, 100, ColumnType::MYSQL_TYPE_STRING, 0, 0
);
packet_sender->sendPacket(column_definition);
}
packet_sender->sendPacket(OK_Packet(0xfe, client_capability_flags, 0, 0, 0), true);
}
void MySQLHandler::comPing()
{
packet_sender->sendPacket(OK_Packet(0x0, client_capability_flags, 0, 0, 0), true);
}
void MySQLHandler::comQuery(const String & payload)
{
bool with_output = false;
std::function<void(const String &)> set_content_type = [&with_output](const String &) -> void {
with_output = true;
};
String query = payload.substr(1);
// Translate query from MySQL to ClickHouse.
// This is a temporary workaround until ClickHouse supports the syntax "@@var_name".
if (query == "select @@version_comment limit 1") // MariaDB client starts session with that query
query = "select ''";
ReadBufferFromString buf(query);
executeQuery(buf, *out, true, connection_context, set_content_type, nullptr);
if (!with_output)
packet_sender->sendPacket(OK_Packet(0x00, client_capability_flags, 0, 0, 0), true);
}
}

View File

@ -0,0 +1,59 @@
#pragma once
#include <Poco/Net/TCPServerConnection.h>
#include <Poco/Net/SecureStreamSocket.h>
#include <Common/getFQDNOrHostName.h>
#include <Core/MySQLProtocol.h>
#include <openssl/rsa.h>
#include "IServer.h"
namespace DB
{
/// Handler for MySQL wire protocol connections. Allows to connect to ClickHouse using MySQL client.
class MySQLHandler : public Poco::Net::TCPServerConnection
{
public:
MySQLHandler(IServer & server_, const Poco::Net::StreamSocket & socket_, RSA & public_key, RSA & private_key, bool ssl_enabled, size_t connection_id);
void run() final;
private:
/// Enables SSL, if client requested.
MySQLProtocol::HandshakeResponse finishHandshake();
void comQuery(const String & payload);
void comFieldList(const String & payload);
void comPing();
void comInitDB(const String & payload);
static String generateScramble();
void authenticate(const MySQLProtocol::HandshakeResponse &, const String & scramble);
IServer & server;
Poco::Logger * log;
Context connection_context;
std::shared_ptr<MySQLProtocol::PacketSender> packet_sender;
size_t connection_id = 0;
size_t server_capability_flags;
size_t client_capability_flags;
RSA & public_key;
RSA & private_key;
std::shared_ptr<ReadBuffer> in;
std::shared_ptr<WriteBuffer> out;
bool secure_connection = false;
std::shared_ptr<Poco::Net::SecureStreamSocket> ss;
};
}

View File

@ -0,0 +1,124 @@
#include <Common/OpenSSLHelpers.h>
#include <Poco/Crypto/X509Certificate.h>
#include <Poco/Net/SSLManager.h>
#include <Poco/Net/TCPServerConnectionFactory.h>
#include <Poco/Util/Application.h>
#include <common/logger_useful.h>
#include <ext/scope_guard.h>
#include "IServer.h"
#include "MySQLHandler.h"
#include "MySQLHandlerFactory.h"
namespace DB
{
namespace ErrorCodes
{
extern const int CANNOT_OPEN_FILE;
extern const int NO_ELEMENTS_IN_CONFIG;
extern const int OPENSSL_ERROR;
extern const int SYSTEM_ERROR;
}
MySQLHandlerFactory::MySQLHandlerFactory(IServer & server_)
: server(server_)
, log(&Logger::get("MySQLHandlerFactory"))
{
try
{
Poco::Net::SSLManager::instance().defaultServerContext();
}
catch (...)
{
LOG_INFO(log, "Failed to create SSL context. SSL will be disabled. Error: " << getCurrentExceptionMessage(false));
ssl_enabled = false;
}
/// Reading rsa keys for SHA256 authentication plugin.
try
{
readRSAKeys();
}
catch (...)
{
LOG_WARNING(log, "Failed to read RSA keys. Error: " << getCurrentExceptionMessage(false));
generateRSAKeys();
}
}
void MySQLHandlerFactory::readRSAKeys()
{
const Poco::Util::LayeredConfiguration & config = Poco::Util::Application::instance().config();
String certificateFileProperty = "openSSL.server.certificateFile";
String privateKeyFileProperty = "openSSL.server.privateKeyFile";
if (!config.has(certificateFileProperty))
throw Exception("Certificate file is not set.", ErrorCodes::NO_ELEMENTS_IN_CONFIG);
if (!config.has(privateKeyFileProperty))
throw Exception("Private key file is not set.", ErrorCodes::NO_ELEMENTS_IN_CONFIG);
{
String certificateFile = config.getString(certificateFileProperty);
FILE * fp = fopen(certificateFile.data(), "r");
if (fp == nullptr)
throw Exception("Cannot open certificate file: " + certificateFile + ".", ErrorCodes::CANNOT_OPEN_FILE);
SCOPE_EXIT(fclose(fp));
X509 * x509 = PEM_read_X509(fp, nullptr, nullptr, nullptr);
SCOPE_EXIT(X509_free(x509));
if (x509 == nullptr)
throw Exception("Failed to read PEM certificate from " + certificateFile + ". Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
EVP_PKEY * p = X509_get_pubkey(x509);
if (p == nullptr)
throw Exception("Failed to get RSA key from X509. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
SCOPE_EXIT(EVP_PKEY_free(p));
public_key.reset(EVP_PKEY_get1_RSA(p));
if (public_key.get() == nullptr)
throw Exception("Failed to get RSA key from ENV_PKEY. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
}
{
String privateKeyFile = config.getString(privateKeyFileProperty);
FILE * fp = fopen(privateKeyFile.data(), "r");
if (fp == nullptr)
throw Exception ("Cannot open private key file " + privateKeyFile + ".", ErrorCodes::CANNOT_OPEN_FILE);
SCOPE_EXIT(fclose(fp));
private_key.reset(PEM_read_RSAPrivateKey(fp, nullptr, nullptr, nullptr));
if (!private_key)
throw Exception("Failed to read RSA private key from " + privateKeyFile + ". Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
}
}
void MySQLHandlerFactory::generateRSAKeys()
{
LOG_INFO(log, "Generating new RSA key.");
public_key.reset(RSA_new());
if (!public_key)
throw Exception("Failed to allocate RSA key. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
BIGNUM * e = BN_new();
if (!e)
throw Exception("Failed to allocate BIGNUM. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
SCOPE_EXIT(BN_free(e));
if (!BN_set_word(e, 65537) || !RSA_generate_key_ex(public_key.get(), 2048, e, nullptr))
throw Exception("Failed to generate RSA key. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
private_key.reset(RSAPrivateKey_dup(public_key.get()));
if (!private_key)
throw Exception("Failed to copy RSA key. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
}
Poco::Net::TCPServerConnection * MySQLHandlerFactory::createConnection(const Poco::Net::StreamSocket & socket)
{
size_t connection_id = last_connection_id++;
LOG_TRACE(log, "MySQL connection. Id: " << connection_id << ". Address: " << socket.peerAddress().toString());
return new MySQLHandler(server, socket, *public_key, *private_key, ssl_enabled, connection_id);
}
}

View File

@ -0,0 +1,39 @@
#pragma once
#include <Poco/Net/TCPServerConnectionFactory.h>
#include <atomic>
#include <openssl/rsa.h>
#include "IServer.h"
namespace DB
{
class MySQLHandlerFactory : public Poco::Net::TCPServerConnectionFactory
{
private:
IServer & server;
Poco::Logger * log;
struct RSADeleter
{
void operator()(RSA * ptr) { RSA_free(ptr); }
};
using RSAPtr = std::unique_ptr<RSA, RSADeleter>;
RSAPtr public_key;
RSAPtr private_key;
bool ssl_enabled = true;
std::atomic<size_t> last_connection_id = 0;
public:
explicit MySQLHandlerFactory(IServer & server_);
void readRSAKeys();
void generateRSAKeys();
Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket) override;
};
}

View File

@ -49,6 +49,7 @@
#include <Common/StatusFile.h>
#include "TCPHandlerFactory.h"
#include "Common/config_version.h"
#include "MySQLHandlerFactory.h"
#if defined(__linux__)
#include <Common/hasLinuxCapability.h>
@ -668,7 +669,7 @@ int Server::main(const std::vector<std::string> & /*args*/)
socket,
new Poco::Net::TCPServerParams));
LOG_INFO(log, "Listening tcp: " + address.toString());
LOG_INFO(log, "Listening for connections with native protocol (tcp): " + address.toString());
}
/// TCP with SSL
@ -685,7 +686,7 @@ int Server::main(const std::vector<std::string> & /*args*/)
server_pool,
socket,
new Poco::Net::TCPServerParams));
LOG_INFO(log, "Listening tcp_secure: " + address.toString());
LOG_INFO(log, "Listening for connections with secure native protocol (tcp_secure): " + address.toString());
#else
throw Exception{"SSL support for TCP protocol is disabled because Poco library was built without NetSSL support.",
ErrorCodes::SUPPORT_IS_DISABLED};
@ -710,7 +711,7 @@ int Server::main(const std::vector<std::string> & /*args*/)
socket,
http_params));
LOG_INFO(log, "Listening interserver http: " + address.toString());
LOG_INFO(log, "Listening for replica communication (interserver) http://" + address.toString());
}
if (config().has("interserver_https_port"))
@ -727,12 +728,27 @@ int Server::main(const std::vector<std::string> & /*args*/)
socket,
http_params));
LOG_INFO(log, "Listening interserver https: " + address.toString());
LOG_INFO(log, "Listening for secure replica communication (interserver) https://" + address.toString());
#else
throw Exception{"SSL support for TCP protocol is disabled because Poco library was built without NetSSL support.",
ErrorCodes::SUPPORT_IS_DISABLED};
#endif
}
if (config().has("mysql_port"))
{
Poco::Net::ServerSocket socket;
auto address = socket_bind_listen(socket, listen_host, config().getInt("mysql_port"), /* secure = */ true);
socket.setReceiveTimeout(Poco::Timespan());
socket.setSendTimeout(settings.send_timeout);
servers.emplace_back(std::make_unique<Poco::Net::TCPServer>(
new MySQLHandlerFactory(*this),
server_pool,
socket,
new Poco::Net::TCPServerParams));
LOG_INFO(log, "Listening for MySQL compatibility protocol: " + address.toString());
}
}
catch (const Poco::Exception & e)
{

View File

@ -1,2 +1,5 @@
*.bin
*.mrk
*.txt
*.dat
*.idx

View File

@ -308,16 +308,88 @@ public:
/**
* Check whether two bitmaps intersect.
* Intersection with an empty set is always 0 (consistent with hasAny).
*/
UInt8 rb_intersect(const RoaringBitmapWithSmallSet & r1)
UInt8 rb_intersect(const RoaringBitmapWithSmallSet & r1) const
{
if (isSmall())
toLarge();
roaring_bitmap_t * rb1 = r1.isSmall() ? r1.getNewRbFromSmall() : r1.getRb();
UInt8 is_true = roaring_bitmap_intersect(rb, rb1);
if (r1.isSmall())
roaring_bitmap_free(rb1);
return is_true;
{
if (r1.isSmall())
{
for (const auto & x : r1.small)
if (small.find(x.getValue()) != small.end())
return 1;
}
else
{
for (const auto & x : small)
if (roaring_bitmap_contains(r1.rb, x.getValue()))
return 1;
}
}
else if (r1.isSmall())
{
for (const auto & x : r1.small)
if (roaring_bitmap_contains(rb, x.getValue()))
return 1;
}
else if (roaring_bitmap_intersect(rb, r1.rb))
return 1;
return 0;
}
/**
* Check whether the argument is the subset of this set.
* Empty set is a subset of any other set (consistent with hasAll).
*/
UInt8 rb_is_subset(const RoaringBitmapWithSmallSet & r1) const
{
if (isSmall())
{
if (r1.isSmall())
{
for (const auto & x : r1.small)
if (small.find(x.getValue()) == small.end())
return 0;
}
else
{
UInt64 r1_size = r1.size();
if (r1_size > small.size())
return 0; // A bigger set can not be a subset of ours.
// This is a rare case with a small number of elements on
// both sides: r1 was promoted to large for some reason and
// it is still not larger than our small set.
// If r1 is our subset then our size must be equal to
// r1_size + number of not found elements, if this sum becomes
// greater then r1 is not a subset.
for (const auto & x : small)
if (!roaring_bitmap_contains(r1.rb, x.getValue()) && ++r1_size > small.size())
return 0;
}
}
else if (r1.isSmall())
{
for (const auto & x : r1.small)
if (!roaring_bitmap_contains(rb, x.getValue()))
return 0;
}
else if (!roaring_bitmap_is_subset(r1.rb, rb))
return 0;
return 1;
}
/**
* Check whether this bitmap contains the argument.
*/
UInt8 rb_contains(const UInt32 x) const
{
return isSmall() ? small.find(x) != small.end() :
roaring_bitmap_contains(rb, x);
}
/**

View File

@ -3,6 +3,8 @@
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <Interpreters/castColumn.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnTuple.h>
#include <Common/FieldVisitors.h>
#include <Common/typeid_cast.h>
#include "AggregateFunctionFactory.h"
@ -34,7 +36,7 @@ namespace
for (size_t i = 0; i < argument_types.size(); ++i)
{
if (!isNumber(argument_types[i]))
if (!isNativeNumber(argument_types[i]))
throw Exception(
"Argument " + std::to_string(i) + " of type " + argument_types[i]->getName()
+ " must be numeric for aggregate function " + name,
@ -110,8 +112,8 @@ namespace
void registerAggregateFunctionMLMethod(AggregateFunctionFactory & factory)
{
factory.registerFunction("LinearRegression", createAggregateFunctionMLMethod<FuncLinearRegression>);
factory.registerFunction("LogisticRegression", createAggregateFunctionMLMethod<FuncLogisticRegression>);
factory.registerFunction("linearRegression", createAggregateFunctionMLMethod<FuncLinearRegression>);
factory.registerFunction("logisticRegression", createAggregateFunctionMLMethod<FuncLogisticRegression>);
}
LinearModelData::LinearModelData(
@ -149,6 +151,27 @@ void LinearModelData::predict(
gradient_computer->predict(container, block, arguments, weights, bias, context);
}
void LinearModelData::returnWeights(IColumn & to) const
{
size_t size = weights.size() + 1;
ColumnArray & arr_to = static_cast<ColumnArray &>(to);
ColumnArray::Offsets & offsets_to = arr_to.getOffsets();
size_t old_size = offsets_to.back();
offsets_to.push_back(old_size + size);
typename ColumnFloat64::Container & val_to
= static_cast<ColumnFloat64 &>(arr_to.getData()).getData();
val_to.reserve(old_size + size);
for (size_t i = 0; i + 1 < size; ++i)
val_to.push_back(weights[i]);
val_to.push_back(bias);
}
void LinearModelData::read(ReadBuffer & buf)
{
readBinary(bias, buf);
@ -345,7 +368,7 @@ void LogisticRegression::predict(
for (size_t i = 1; i < arguments.size(); ++i)
{
const ColumnWithTypeAndName & cur_col = block.getByPosition(arguments[i]);
if (!isNumber(cur_col.type))
if (!isNativeNumber(cur_col.type))
{
throw Exception("Prediction arguments must have numeric type", ErrorCodes::BAD_ARGUMENTS);
}
@ -418,7 +441,7 @@ void LinearRegression::predict(
for (size_t i = 1; i < arguments.size(); ++i)
{
const ColumnWithTypeAndName & cur_col = block.getByPosition(arguments[i]);
if (!isNumber(cur_col.type))
if (!isNativeNumber(cur_col.type))
{
throw Exception("Prediction arguments must have numeric type", ErrorCodes::BAD_ARGUMENTS);
}

View File

@ -4,6 +4,8 @@
#include <Columns/ColumnsCommon.h>
#include <Columns/ColumnsNumber.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypeArray.h>
#include "IAggregateFunction.h"
namespace DB
@ -218,6 +220,7 @@ public:
void
predict(ColumnVector<Float64>::Container & container, Block & block, const ColumnNumbers & arguments, const Context & context) const;
void returnWeights(IColumn & to) const;
private:
std::vector<Float64> weights;
Float64 bias{0.0};
@ -269,7 +272,15 @@ public:
{
}
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeNumber<Float64>>(); }
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeFloat64>());
}
DataTypePtr getReturnTypeToPredict() const override
{
return std::make_shared<DataTypeNumber<Float64>>();
}
void create(AggregateDataPtr place) const override
{
@ -301,11 +312,12 @@ public:
this->data(place).predict(column.getData(), block, arguments, context);
}
/** This function is called if aggregate function without State modifier is selected in a query.
* Inserts all weights of the model into the column 'to', so user may use such information if needed
*/
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
std::ignore = place;
std::ignore = to;
throw std::runtime_error("not implemented");
this->data(place).returnWeights(to);
}
const char * getHeaderFilePath() const override { return __FILE__; }
@ -321,10 +333,10 @@ private:
struct NameLinearRegression
{
static constexpr auto name = "LinearRegression";
static constexpr auto name = "linearRegression";
};
struct NameLogisticRegression
{
static constexpr auto name = "LogisticRegression";
static constexpr auto name = "logisticRegression";
};
}

View File

@ -61,10 +61,10 @@ public:
AggregateFunctionIntersectionsMax(AggregateFunctionIntersectionsKind kind_, const DataTypes & arguments)
: IAggregateFunctionDataHelper<MaxIntersectionsData<PointType>, AggregateFunctionIntersectionsMax<PointType>>(arguments, {}), kind(kind_)
{
if (!isNumber(arguments[0]))
if (!isNativeNumber(arguments[0]))
throw Exception{getName() + ": first argument must be represented by integer", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
if (!isNumber(arguments[1]))
if (!isNativeNumber(arguments[1]))
throw Exception{getName() + ": second argument must be represented by integer", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
if (!arguments[0]->equals(*arguments[1]))

View File

@ -1,6 +1,12 @@
#include <AggregateFunctions/Helpers.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/AggregateFunctionSequenceMatch.h>
#include <DataTypes/DataTypeDate.h>
#include <DataTypes/DataTypeDateTime.h>
#include <ext/range.h>
namespace DB
{
@ -12,32 +18,58 @@ namespace ErrorCodes
namespace
{
AggregateFunctionPtr createAggregateFunctionSequenceCount(const std::string & name, const DataTypes & argument_types, const Array & params)
template <template <typename, typename> class AggregateFunction, template <typename> class Data>
AggregateFunctionPtr createAggregateFunctionSequenceBase(const std::string & name, const DataTypes & argument_types, const Array & params)
{
if (params.size() != 1)
throw Exception{"Aggregate function " + name + " requires exactly one parameter.",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};
String pattern = params.front().safeGet<std::string>();
return std::make_shared<AggregateFunctionSequenceCount>(argument_types, params, pattern);
}
const auto arg_count = argument_types.size();
AggregateFunctionPtr createAggregateFunctionSequenceMatch(const std::string & name, const DataTypes & argument_types, const Array & params)
{
if (params.size() != 1)
throw Exception{"Aggregate function " + name + " requires exactly one parameter.",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};
if (arg_count < 3)
throw Exception{"Aggregate function " + name + " requires at least 3 arguments.",
ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION};
if (arg_count - 1 > max_events)
throw Exception{"Aggregate function " + name + " supports up to "
+ toString(max_events) + " event arguments.",
ErrorCodes::TOO_MANY_ARGUMENTS_FOR_FUNCTION};
const auto time_arg = argument_types.front().get();
for (const auto i : ext::range(1, arg_count))
{
const auto cond_arg = argument_types[i].get();
if (!isUInt8(cond_arg))
throw Exception{"Illegal type " + cond_arg->getName() + " of argument " + toString(i + 1)
+ " of aggregate function " + name + ", must be UInt8",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
}
String pattern = params.front().safeGet<std::string>();
return std::make_shared<AggregateFunctionSequenceMatch>(argument_types, params, pattern);
AggregateFunctionPtr res(createWithUnsignedIntegerType<AggregateFunction, Data>(*argument_types[0], argument_types, params, pattern));
if (res)
return res;
WhichDataType which(argument_types.front().get());
if (which.isDateTime())
return std::make_shared<AggregateFunction<DataTypeDateTime::FieldType, Data<DataTypeDateTime::FieldType>>>(argument_types, params, pattern);
else if (which.isDate())
return std::make_shared<AggregateFunction<DataTypeDate::FieldType, Data<DataTypeDate::FieldType>>>(argument_types, params, pattern);
throw Exception{"Illegal type " + time_arg->getName() + " of first argument of aggregate function "
+ name + ", must be DateTime",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
}
}
void registerAggregateFunctionsSequenceMatch(AggregateFunctionFactory & factory)
{
factory.registerFunction("sequenceMatch", createAggregateFunctionSequenceMatch);
factory.registerFunction("sequenceCount", createAggregateFunctionSequenceCount);
factory.registerFunction("sequenceMatch", createAggregateFunctionSequenceBase<AggregateFunctionSequenceMatch, AggregateFunctionSequenceMatchData>);
factory.registerFunction("sequenceCount", createAggregateFunctionSequenceBase<AggregateFunctionSequenceCount, AggregateFunctionSequenceMatchData>);
}
}

View File

@ -36,11 +36,12 @@ struct ComparePairFirst final
}
};
static constexpr auto max_events = 32;
template <typename T>
struct AggregateFunctionSequenceMatchData final
{
static constexpr auto max_events = 32;
using Timestamp = std::uint32_t;
using Timestamp = T;
using Events = std::bitset<max_events>;
using TimestampEvents = std::pair<Timestamp, Events>;
using Comparator = ComparePairFirst<std::less>;
@ -61,6 +62,9 @@ struct AggregateFunctionSequenceMatchData final
void merge(const AggregateFunctionSequenceMatchData & other)
{
if (other.events_list.empty())
return;
const auto size = events_list.size();
events_list.insert(std::begin(other.events_list), std::end(other.events_list));
@ -119,7 +123,7 @@ struct AggregateFunctionSequenceMatchData final
for (size_t i = 0; i < size; ++i)
{
std::uint32_t timestamp;
Timestamp timestamp;
readBinary(timestamp, buf);
UInt64 events;
@ -135,48 +139,23 @@ struct AggregateFunctionSequenceMatchData final
constexpr auto sequence_match_max_iterations = 1000000;
template <typename Derived>
class AggregateFunctionSequenceBase : public IAggregateFunctionDataHelper<AggregateFunctionSequenceMatchData, Derived>
template <typename T, typename Data, typename Derived>
class AggregateFunctionSequenceBase : public IAggregateFunctionDataHelper<Data, Derived>
{
public:
AggregateFunctionSequenceBase(const DataTypes & arguments, const Array & params, const String & pattern)
: IAggregateFunctionDataHelper<AggregateFunctionSequenceMatchData, Derived>(arguments, params)
: IAggregateFunctionDataHelper<Data, Derived>(arguments, params)
, pattern(pattern)
{
arg_count = arguments.size();
if (!sufficientArgs(arg_count))
throw Exception{"Aggregate function " + derived().getName() + " requires at least 3 arguments.",
ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION};
if (arg_count - 1 > AggregateFunctionSequenceMatchData::max_events)
throw Exception{"Aggregate function " + derived().getName() + " supports up to " +
toString(AggregateFunctionSequenceMatchData::max_events) + " event arguments.",
ErrorCodes::TOO_MANY_ARGUMENTS_FOR_FUNCTION};
const auto time_arg = arguments.front().get();
if (!WhichDataType(time_arg).isDateTime())
throw Exception{"Illegal type " + time_arg->getName() + " of first argument of aggregate function "
+ derived().getName() + ", must be DateTime",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
for (const auto i : ext::range(1, arg_count))
{
const auto cond_arg = arguments[i].get();
if (!isUInt8(cond_arg))
throw Exception{"Illegal type " + cond_arg->getName() + " of argument " + toString(i + 1) +
" of aggregate function " + derived().getName() + ", must be UInt8",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
}
parsePattern();
}
void add(AggregateDataPtr place, const IColumn ** columns, const size_t row_num, Arena *) const override
{
const auto timestamp = static_cast<const ColumnUInt32 *>(columns[0])->getData()[row_num];
const auto timestamp = static_cast<const ColumnVector<T> *>(columns[0])->getData()[row_num];
AggregateFunctionSequenceMatchData::Events events;
typename Data::Events events;
for (const auto i : ext::range(1, arg_count))
{
const auto event = static_cast<const ColumnUInt8 *>(columns[i])->getData()[row_num];
@ -218,17 +197,15 @@ private:
struct PatternAction final
{
PatternActionType type;
std::uint32_t extra;
std::uint64_t extra;
PatternAction() = default;
PatternAction(const PatternActionType type, const std::uint32_t extra = 0) : type{type}, extra{extra} {}
PatternAction(const PatternActionType type, const std::uint64_t extra = 0) : type{type}, extra{extra} {}
};
static constexpr size_t bytes_on_stack = 64;
using PatternActions = PODArray<PatternAction, bytes_on_stack, AllocatorWithStackMemory<Allocator<false>, bytes_on_stack>>;
static bool sufficientArgs(const size_t arg_count) { return arg_count >= 3; }
Derived & derived() { return static_cast<Derived &>(*this); }
void parsePattern()
@ -340,8 +317,8 @@ protected:
/// This algorithm performs in O(mn) (with m the number of DFA states and N the number
/// of events) with a memory consumption and memory allocations in O(m). It means that
/// if n >>> m (which is expected to be the case), this algorithm can be considered linear.
template <typename T>
bool dfaMatch(T & events_it, const T events_end) const
template <typename EventEntry>
bool dfaMatch(EventEntry & events_it, const EventEntry events_end) const
{
using ActiveStates = std::vector<bool>;
@ -396,8 +373,8 @@ protected:
return active_states.back();
}
template <typename T>
bool backtrackingMatch(T & events_it, const T events_end) const
template <typename EventEntry>
bool backtrackingMatch(EventEntry & events_it, const EventEntry events_end) const
{
const auto action_begin = std::begin(actions);
const auto action_end = std::end(actions);
@ -407,7 +384,7 @@ protected:
auto base_it = events_it;
/// an iterator to action plus an iterator to row in events list plus timestamp at the start of sequence
using backtrack_info = std::tuple<decltype(action_it), T, T>;
using backtrack_info = std::tuple<decltype(action_it), EventEntry, EventEntry>;
std::stack<backtrack_info> back_stack;
/// backtrack if possible
@ -458,7 +435,7 @@ protected:
}
else if (action_it->type == PatternActionType::TimeLessOrEqual)
{
if (events_it->first - base_it->first <= action_it->extra)
if (events_it->first <= base_it->first + action_it->extra)
{
/// condition satisfied, move onto next action
back_stack.emplace(action_it, events_it, base_it);
@ -470,7 +447,7 @@ protected:
}
else if (action_it->type == PatternActionType::TimeLess)
{
if (events_it->first - base_it->first < action_it->extra)
if (events_it->first < base_it->first + action_it->extra)
{
back_stack.emplace(action_it, events_it, base_it);
base_it = events_it;
@ -481,7 +458,7 @@ protected:
}
else if (action_it->type == PatternActionType::TimeGreaterOrEqual)
{
if (events_it->first - base_it->first >= action_it->extra)
if (events_it->first >= base_it->first + action_it->extra)
{
back_stack.emplace(action_it, events_it, base_it);
base_it = events_it;
@ -492,7 +469,7 @@ protected:
}
else if (action_it->type == PatternActionType::TimeGreater)
{
if (events_it->first - base_it->first > action_it->extra)
if (events_it->first > base_it->first + action_it->extra)
{
back_stack.emplace(action_it, events_it, base_it);
base_it = events_it;
@ -575,14 +552,14 @@ private:
DFAStates dfa_states;
};
class AggregateFunctionSequenceMatch final : public AggregateFunctionSequenceBase<AggregateFunctionSequenceMatch>
template <typename T, typename Data>
class AggregateFunctionSequenceMatch final : public AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>
{
public:
AggregateFunctionSequenceMatch(const DataTypes & arguments, const Array & params, const String & pattern)
: AggregateFunctionSequenceBase<AggregateFunctionSequenceMatch>(arguments, params, pattern) {}
: AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>(arguments, params, pattern) {}
using AggregateFunctionSequenceBase<AggregateFunctionSequenceMatch>::AggregateFunctionSequenceBase;
using AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>::AggregateFunctionSequenceBase;
String getName() const override { return "sequenceMatch"; }
@ -590,27 +567,27 @@ public:
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
const_cast<Data &>(data(place)).sort();
const_cast<Data &>(this->data(place)).sort();
const auto & data_ref = data(place);
const auto & data_ref = this->data(place);
const auto events_begin = std::begin(data_ref.events_list);
const auto events_end = std::end(data_ref.events_list);
auto events_it = events_begin;
bool match = pattern_has_time ? backtrackingMatch(events_it, events_end) : dfaMatch(events_it, events_end);
bool match = this->pattern_has_time ? this->backtrackingMatch(events_it, events_end) : this->dfaMatch(events_it, events_end);
static_cast<ColumnUInt8 &>(to).getData().push_back(match);
}
};
class AggregateFunctionSequenceCount final : public AggregateFunctionSequenceBase<AggregateFunctionSequenceCount>
template <typename T, typename Data>
class AggregateFunctionSequenceCount final : public AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>
{
public:
AggregateFunctionSequenceCount(const DataTypes & arguments, const Array & params, const String & pattern)
: AggregateFunctionSequenceBase<AggregateFunctionSequenceCount>(arguments, params, pattern) {}
: AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>(arguments, params, pattern) {}
using AggregateFunctionSequenceBase<AggregateFunctionSequenceCount>::AggregateFunctionSequenceBase;
using AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>::AggregateFunctionSequenceBase;
String getName() const override { return "sequenceCount"; }
@ -618,21 +595,21 @@ public:
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
const_cast<Data &>(data(place)).sort();
const_cast<Data &>(this->data(place)).sort();
static_cast<ColumnUInt64 &>(to).getData().push_back(count(place));
}
private:
UInt64 count(const ConstAggregateDataPtr & place) const
{
const auto & data_ref = data(place);
const auto & data_ref = this->data(place);
const auto events_begin = std::begin(data_ref.events_list);
const auto events_end = std::end(data_ref.events_list);
auto events_it = events_begin;
size_t count = 0;
while (events_it != events_end && backtrackingMatch(events_it, events_end))
while (events_it != events_end && this->backtrackingMatch(events_it, events_end))
++count;
return count;

View File

@ -1,8 +1,9 @@
#include <AggregateFunctions/AggregateFunctionLeastSqr.h>
#include <AggregateFunctions/AggregateFunctionSimpleLinearRegression.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <Core/TypeListNumber.h>
namespace DB
{
@ -10,7 +11,7 @@ namespace DB
namespace
{
AggregateFunctionPtr createAggregateFunctionLeastSqr(
AggregateFunctionPtr createAggregateFunctionSimpleLinearRegression(
const String & name,
const DataTypes & arguments,
const Array & params
@ -20,16 +21,11 @@ AggregateFunctionPtr createAggregateFunctionLeastSqr(
assertBinary(name, arguments);
const IDataType * x_arg = arguments.front().get();
WhichDataType which_x {
x_arg
};
WhichDataType which_x = x_arg;
const IDataType * y_arg = arguments.back().get();
WhichDataType which_y = y_arg;
WhichDataType which_y {
y_arg
};
#define FOR_LEASTSQR_TYPES_2(M, T) \
M(T, UInt8) \
@ -55,7 +51,7 @@ AggregateFunctionPtr createAggregateFunctionLeastSqr(
FOR_LEASTSQR_TYPES_2(M, Float64)
#define DISPATCH(T1, T2) \
if (which_x.idx == TypeIndex::T1 && which_y.idx == TypeIndex::T2) \
return std::make_shared<AggregateFunctionLeastSqr<T1, T2>>( \
return std::make_shared<AggregateFunctionSimpleLinearRegression<T1, T2>>( \
arguments, \
params \
);
@ -77,9 +73,9 @@ AggregateFunctionPtr createAggregateFunctionLeastSqr(
}
void registerAggregateFunctionLeastSqr(AggregateFunctionFactory & factory)
void registerAggregateFunctionSimpleLinearRegression(AggregateFunctionFactory & factory)
{
factory.registerFunction("leastSqr", createAggregateFunctionLeastSqr);
factory.registerFunction("simpleLinearRegression", createAggregateFunctionSimpleLinearRegression);
}
}

View File

@ -19,7 +19,7 @@ namespace ErrorCodes
}
template <typename X, typename Y, typename Ret>
struct AggregateFunctionLeastSqrData final
struct AggregateFunctionSimpleLinearRegressionData final
{
size_t count = 0;
Ret sum_x = 0;
@ -36,7 +36,7 @@ struct AggregateFunctionLeastSqrData final
sum_xy += x * y;
}
void merge(const AggregateFunctionLeastSqrData & other)
void merge(const AggregateFunctionSimpleLinearRegressionData & other)
{
count += other.count;
sum_x += other.sum_x;
@ -85,19 +85,19 @@ struct AggregateFunctionLeastSqrData final
/// Calculates simple linear regression parameters.
/// Result is a tuple (k, b) for y = k * x + b equation, solved by least squares approximation.
template <typename X, typename Y, typename Ret = Float64>
class AggregateFunctionLeastSqr final : public IAggregateFunctionDataHelper<
AggregateFunctionLeastSqrData<X, Y, Ret>,
AggregateFunctionLeastSqr<X, Y, Ret>
class AggregateFunctionSimpleLinearRegression final : public IAggregateFunctionDataHelper<
AggregateFunctionSimpleLinearRegressionData<X, Y, Ret>,
AggregateFunctionSimpleLinearRegression<X, Y, Ret>
>
{
public:
AggregateFunctionLeastSqr(
AggregateFunctionSimpleLinearRegression(
const DataTypes & arguments,
const Array & params
):
IAggregateFunctionDataHelper<
AggregateFunctionLeastSqrData<X, Y, Ret>,
AggregateFunctionLeastSqr<X, Y, Ret>
AggregateFunctionSimpleLinearRegressionData<X, Y, Ret>,
AggregateFunctionSimpleLinearRegression<X, Y, Ret>
> {arguments, params}
{
// notice: arguments has been checked before
@ -105,7 +105,7 @@ public:
String getName() const override
{
return "leastSqr";
return "simpleLinearRegression";
}
const char * getHeaderFilePath() const override
@ -120,12 +120,8 @@ public:
Arena *
) const override
{
auto col_x {
static_cast<const ColumnVector<X> *>(columns[0])
};
auto col_y {
static_cast<const ColumnVector<Y> *>(columns[1])
};
auto col_x = static_cast<const ColumnVector<X> *>(columns[0]);
auto col_y = static_cast<const ColumnVector<Y> *>(columns[1]);
X x = col_x->getData()[row_num];
Y y = col_y->getData()[row_num];
@ -159,12 +155,14 @@ public:
DataTypePtr getReturnType() const override
{
DataTypes types {
DataTypes types
{
std::make_shared<DataTypeNumber<Ret>>(),
std::make_shared<DataTypeNumber<Ret>>(),
};
Strings names {
Strings names
{
"k",
"b",
};

View File

@ -56,6 +56,10 @@ void registerAggregateFunctionsStatisticsSimple(AggregateFunctionFactory & facto
factory.registerFunction("varPop", createAggregateFunctionStatisticsUnary<AggregateFunctionVarPopSimple>);
factory.registerFunction("stddevSamp", createAggregateFunctionStatisticsUnary<AggregateFunctionStddevSampSimple>);
factory.registerFunction("stddevPop", createAggregateFunctionStatisticsUnary<AggregateFunctionStddevPopSimple>);
factory.registerFunction("skewSamp", createAggregateFunctionStatisticsUnary<AggregateFunctionSkewSampSimple>);
factory.registerFunction("skewPop", createAggregateFunctionStatisticsUnary<AggregateFunctionSkewPopSimple>);
factory.registerFunction("kurtSamp", createAggregateFunctionStatisticsUnary<AggregateFunctionKurtSampSimple>);
factory.registerFunction("kurtPop", createAggregateFunctionStatisticsUnary<AggregateFunctionKurtPopSimple>);
factory.registerFunction("covarSamp", createAggregateFunctionStatisticsBinary<AggregateFunctionCovarSampSimple>);
factory.registerFunction("covarPop", createAggregateFunctionStatisticsBinary<AggregateFunctionCovarPopSimple>);

View File

@ -32,30 +32,42 @@ namespace DB
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
extern const int DECIMAL_OVERFLOW;
}
template <typename T>
/**
Calculating univariate central moments
Levels:
level 2 (pop & samp): var, stddev
level 3: skewness
level 4: kurtosis
References:
https://en.wikipedia.org/wiki/Moment_(mathematics)
https://en.wikipedia.org/wiki/Skewness
https://en.wikipedia.org/wiki/Kurtosis
*/
template <typename T, size_t _level>
struct VarMoments
{
T m0{};
T m1{};
T m2{};
T m[_level + 1]{};
void add(T x)
{
++m0;
m1 += x;
m2 += x * x;
++m[0];
m[1] += x;
m[2] += x * x;
if constexpr (_level >= 3) m[3] += x * x * x;
if constexpr (_level >= 4) m[4] += x * x * x * x;
}
void merge(const VarMoments & rhs)
{
m0 += rhs.m0;
m1 += rhs.m1;
m2 += rhs.m2;
m[0] += rhs.m[0];
m[1] += rhs.m[1];
m[2] += rhs.m[2];
if constexpr (_level >= 3) m[3] += rhs.m[3];
if constexpr (_level >= 4) m[4] += rhs.m[4];
}
void write(WriteBuffer & buf) const
@ -70,45 +82,90 @@ struct VarMoments
T NO_SANITIZE_UNDEFINED getPopulation() const
{
return (m2 - m1 * m1 / m0) / m0;
return (m[2] - m[1] * m[1] / m[0]) / m[0];
}
T NO_SANITIZE_UNDEFINED getSample() const
{
if (m0 == 0)
if (m[0] == 0)
return std::numeric_limits<T>::quiet_NaN();
return (m2 - m1 * m1 / m0) / (m0 - 1);
return (m[2] - m[1] * m[1] / m[0]) / (m[0] - 1);
}
T get() const { throw Exception("Unexpected call", ErrorCodes::LOGICAL_ERROR); }
T NO_SANITIZE_UNDEFINED getMoment3() const
{
// to avoid accuracy problem
if (m[0] == 1)
return 0;
return (m[3]
- (3 * m[2]
- 2 * m[1] * m[1] / m[0]
) * m[1] / m[0]
) / m[0];
}
T NO_SANITIZE_UNDEFINED getMoment4() const
{
// to avoid accuracy problem
if (m[0] == 1)
return 0;
return (m[4]
- (4 * m[3]
- (6 * m[2]
- 3 * m[1] * m[1] / m[0]
) * m[1] / m[0]
) * m[1] / m[0]
) / m[0];
}
};
template <typename T>
template <typename T, size_t _level>
struct VarMomentsDecimal
{
using NativeType = typename T::NativeType;
UInt64 m0{};
NativeType m1{};
NativeType m2{};
NativeType m[_level]{};
NativeType & getM(size_t i)
{
return m[i - 1];
}
const NativeType & getM(size_t i) const
{
return m[i - 1];
}
void add(NativeType x)
{
++m0;
m1 += x;
getM(1) += x;
NativeType tmp; /// scale' = 2 * scale
if (common::mulOverflow(x, x, tmp) || common::addOverflow(m2, tmp, m2))
NativeType tmp;
if (common::mulOverflow(x, x, tmp) || common::addOverflow(getM(2), tmp, getM(2)))
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
if constexpr (_level >= 3)
if (common::mulOverflow(tmp, x, tmp) || common::addOverflow(getM(3), tmp, getM(3)))
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
if constexpr (_level >= 4)
if (common::mulOverflow(tmp, x, tmp) || common::addOverflow(getM(4), tmp, getM(4)))
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
}
void merge(const VarMomentsDecimal & rhs)
{
m0 += rhs.m0;
m1 += rhs.m1;
getM(1) += rhs.getM(1);
if (common::addOverflow(m2, rhs.m2, m2))
if (common::addOverflow(getM(2), rhs.getM(2), getM(2)))
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
if constexpr (_level >= 3)
if (common::addOverflow(getM(3), rhs.getM(3), getM(3)))
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
if constexpr (_level >= 4)
if (common::addOverflow(getM(4), rhs.getM(4), getM(4)))
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
}
void write(WriteBuffer & buf) const { writePODBinary(*this, buf); }
@ -120,8 +177,8 @@ struct VarMomentsDecimal
return std::numeric_limits<Float64>::infinity();
NativeType tmp;
if (common::mulOverflow(m1, m1, tmp) ||
common::subOverflow(m2, NativeType(tmp / m0), tmp))
if (common::mulOverflow(getM(1), getM(1), tmp) ||
common::subOverflow(getM(2), NativeType(tmp / m0), tmp))
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
return convertFromDecimal<DataTypeDecimal<T>, DataTypeNumber<Float64>>(tmp / m0, scale);
}
@ -134,15 +191,50 @@ struct VarMomentsDecimal
return std::numeric_limits<Float64>::infinity();
NativeType tmp;
if (common::mulOverflow(m1, m1, tmp) ||
common::subOverflow(m2, NativeType(tmp / m0), tmp))
if (common::mulOverflow(getM(1), getM(1), tmp) ||
common::subOverflow(getM(2), NativeType(tmp / m0), tmp))
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
return convertFromDecimal<DataTypeDecimal<T>, DataTypeNumber<Float64>>(tmp / (m0 - 1), scale);
}
Float64 get() const { throw Exception("Unexpected call", ErrorCodes::LOGICAL_ERROR); }
Float64 getMoment3(UInt32 scale) const
{
if (m0 == 0)
return std::numeric_limits<Float64>::infinity();
NativeType tmp;
if (common::mulOverflow(2 * getM(1), getM(1), tmp) ||
common::subOverflow(3 * getM(2), NativeType(tmp / m0), tmp) ||
common::mulOverflow(tmp, getM(1), tmp) ||
common::subOverflow(getM(3), NativeType(tmp / m0), tmp))
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
return convertFromDecimal<DataTypeDecimal<T>, DataTypeNumber<Float64>>(tmp / m0, scale);
}
Float64 getMoment4(UInt32 scale) const
{
if (m0 == 0)
return std::numeric_limits<Float64>::infinity();
NativeType tmp;
if (common::mulOverflow(3 * getM(1), getM(1), tmp) ||
common::subOverflow(6 * getM(2), NativeType(tmp / m0), tmp) ||
common::mulOverflow(tmp, getM(1), tmp) ||
common::subOverflow(4 * getM(3), NativeType(tmp / m0), tmp) ||
common::mulOverflow(tmp, getM(1), tmp) ||
common::subOverflow(getM(4), NativeType(tmp / m0), tmp))
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
return convertFromDecimal<DataTypeDecimal<T>, DataTypeNumber<Float64>>(tmp / m0, scale);
}
};
/**
Calculating multivariate central moments
Levels:
level 2 (pop & samp): covar
References:
https://en.wikipedia.org/wiki/Moment_(mathematics)
*/
template <typename T>
struct CovarMoments
{
@ -188,8 +280,6 @@ struct CovarMoments
return std::numeric_limits<T>::quiet_NaN();
return (xy - x1 * y1 / m0) / (m0 - 1);
}
T get() const { throw Exception("Unexpected call", ErrorCodes::LOGICAL_ERROR); }
};
template <typename T>
@ -236,9 +326,6 @@ struct CorrMoments
{
return (m0 * xy - x1 * y1) / sqrt((m0 * x2 - x1 * x1) * (m0 * y2 - y1 * y1));
}
T getPopulation() const { throw Exception("Unexpected call", ErrorCodes::LOGICAL_ERROR); }
T getSample() const { throw Exception("Unexpected call", ErrorCodes::LOGICAL_ERROR); }
};
@ -246,18 +333,20 @@ enum class StatisticsFunctionKind
{
varPop, varSamp,
stddevPop, stddevSamp,
skewPop, skewSamp,
kurtPop, kurtSamp,
covarPop, covarSamp,
corr
};
template <typename T, StatisticsFunctionKind _kind>
template <typename T, StatisticsFunctionKind _kind, size_t _level>
struct StatFuncOneArg
{
using Type1 = T;
using Type2 = T;
using ResultType = std::conditional_t<std::is_same_v<T, Float32>, Float32, Float64>;
using Data = std::conditional_t<IsDecimalNumber<T>, VarMomentsDecimal<Decimal128>, VarMoments<ResultType>>;
using Data = std::conditional_t<IsDecimalNumber<T>, VarMomentsDecimal<Decimal128, _level>, VarMoments<ResultType, _level>>;
static constexpr StatisticsFunctionKind kind = _kind;
static constexpr UInt32 num_args = 1;
@ -300,17 +389,28 @@ public:
String getName() const override
{
switch (StatFunc::kind)
{
case StatisticsFunctionKind::varPop: return "varPop";
case StatisticsFunctionKind::varSamp: return "varSamp";
case StatisticsFunctionKind::stddevPop: return "stddevPop";
case StatisticsFunctionKind::stddevSamp: return "stddevSamp";
case StatisticsFunctionKind::covarPop: return "covarPop";
case StatisticsFunctionKind::covarSamp: return "covarSamp";
case StatisticsFunctionKind::corr: return "corr";
}
__builtin_unreachable();
if constexpr (StatFunc::kind == StatisticsFunctionKind::varPop)
return "varPop";
if constexpr (StatFunc::kind == StatisticsFunctionKind::varSamp)
return "varSamp";
if constexpr (StatFunc::kind == StatisticsFunctionKind::stddevPop)
return "stddevPop";
if constexpr (StatFunc::kind == StatisticsFunctionKind::stddevSamp)
return "stddevSamp";
if constexpr (StatFunc::kind == StatisticsFunctionKind::skewPop)
return "skewPop";
if constexpr (StatFunc::kind == StatisticsFunctionKind::skewSamp)
return "skewSamp";
if constexpr (StatFunc::kind == StatisticsFunctionKind::kurtPop)
return "kurtPop";
if constexpr (StatFunc::kind == StatisticsFunctionKind::kurtSamp)
return "kurtSamp";
if constexpr (StatFunc::kind == StatisticsFunctionKind::covarPop)
return "covarPop";
if constexpr (StatFunc::kind == StatisticsFunctionKind::covarSamp)
return "covarSamp";
if constexpr (StatFunc::kind == StatisticsFunctionKind::corr)
return "corr";
}
DataTypePtr getReturnType() const override
@ -351,28 +451,103 @@ public:
if constexpr (IsDecimalNumber<T1>)
{
switch (StatFunc::kind)
if constexpr (StatFunc::kind == StatisticsFunctionKind::varPop)
dst.push_back(data.getPopulation(src_scale * 2));
if constexpr (StatFunc::kind == StatisticsFunctionKind::varSamp)
dst.push_back(data.getSample(src_scale * 2));
if constexpr (StatFunc::kind == StatisticsFunctionKind::stddevPop)
dst.push_back(sqrt(data.getPopulation(src_scale * 2)));
if constexpr (StatFunc::kind == StatisticsFunctionKind::stddevSamp)
dst.push_back(sqrt(data.getSample(src_scale * 2)));
if constexpr (StatFunc::kind == StatisticsFunctionKind::skewPop)
{
case StatisticsFunctionKind::varPop: dst.push_back(data.getPopulation(src_scale * 2)); break;
case StatisticsFunctionKind::varSamp: dst.push_back(data.getSample(src_scale * 2)); break;
case StatisticsFunctionKind::stddevPop: dst.push_back(sqrt(data.getPopulation(src_scale * 2))); break;
case StatisticsFunctionKind::stddevSamp: dst.push_back(sqrt(data.getSample(src_scale * 2))); break;
default:
__builtin_unreachable();
Float64 var_value = data.getPopulation(src_scale * 2);
if (var_value > 0)
dst.push_back(data.getMoment3(src_scale * 3) / pow(var_value, 1.5));
else
dst.push_back(std::numeric_limits<Float64>::quiet_NaN());
}
if constexpr (StatFunc::kind == StatisticsFunctionKind::skewSamp)
{
Float64 var_value = data.getSample(src_scale * 2);
if (var_value > 0)
dst.push_back(data.getMoment3(src_scale * 3) / pow(var_value, 1.5));
else
dst.push_back(std::numeric_limits<Float64>::quiet_NaN());
}
if constexpr (StatFunc::kind == StatisticsFunctionKind::kurtPop)
{
Float64 var_value = data.getPopulation(src_scale * 2);
if (var_value > 0)
dst.push_back(data.getMoment4(src_scale * 4) / pow(var_value, 2));
else
dst.push_back(std::numeric_limits<Float64>::quiet_NaN());
}
if constexpr (StatFunc::kind == StatisticsFunctionKind::kurtSamp)
{
Float64 var_value = data.getSample(src_scale * 2);
if (var_value > 0)
dst.push_back(data.getMoment4(src_scale * 4) / pow(var_value, 2));
else
dst.push_back(std::numeric_limits<Float64>::quiet_NaN());
}
}
else
{
switch (StatFunc::kind)
if constexpr (StatFunc::kind == StatisticsFunctionKind::varPop)
dst.push_back(data.getPopulation());
if constexpr (StatFunc::kind == StatisticsFunctionKind::varSamp)
dst.push_back(data.getSample());
if constexpr (StatFunc::kind == StatisticsFunctionKind::stddevPop)
dst.push_back(sqrt(data.getPopulation()));
if constexpr (StatFunc::kind == StatisticsFunctionKind::stddevSamp)
dst.push_back(sqrt(data.getSample()));
if constexpr (StatFunc::kind == StatisticsFunctionKind::skewPop)
{
case StatisticsFunctionKind::varPop: dst.push_back(data.getPopulation()); break;
case StatisticsFunctionKind::varSamp: dst.push_back(data.getSample()); break;
case StatisticsFunctionKind::stddevPop: dst.push_back(sqrt(data.getPopulation())); break;
case StatisticsFunctionKind::stddevSamp: dst.push_back(sqrt(data.getSample())); break;
case StatisticsFunctionKind::covarPop: dst.push_back(data.getPopulation()); break;
case StatisticsFunctionKind::covarSamp: dst.push_back(data.getSample()); break;
case StatisticsFunctionKind::corr: dst.push_back(data.get()); break;
ResultType var_value = data.getPopulation();
if (var_value > 0)
dst.push_back(data.getMoment3() / pow(var_value, 1.5));
else
dst.push_back(std::numeric_limits<ResultType>::quiet_NaN());
}
if constexpr (StatFunc::kind == StatisticsFunctionKind::skewSamp)
{
ResultType var_value = data.getSample();
if (var_value > 0)
dst.push_back(data.getMoment3() / pow(var_value, 1.5));
else
dst.push_back(std::numeric_limits<ResultType>::quiet_NaN());
}
if constexpr (StatFunc::kind == StatisticsFunctionKind::kurtPop)
{
ResultType var_value = data.getPopulation();
if (var_value > 0)
dst.push_back(data.getMoment4() / pow(var_value, 2));
else
dst.push_back(std::numeric_limits<ResultType>::quiet_NaN());
}
if constexpr (StatFunc::kind == StatisticsFunctionKind::kurtSamp)
{
ResultType var_value = data.getSample();
if (var_value > 0)
dst.push_back(data.getMoment4() / pow(var_value, 2));
else
dst.push_back(std::numeric_limits<ResultType>::quiet_NaN());
}
if constexpr (StatFunc::kind == StatisticsFunctionKind::covarPop)
dst.push_back(data.getPopulation());
if constexpr (StatFunc::kind == StatisticsFunctionKind::covarSamp)
dst.push_back(data.getSample());
if constexpr (StatFunc::kind == StatisticsFunctionKind::corr)
dst.push_back(data.get());
}
}
@ -383,10 +558,14 @@ private:
};
template <typename T> using AggregateFunctionVarPopSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::varPop>>;
template <typename T> using AggregateFunctionVarSampSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::varSamp>>;
template <typename T> using AggregateFunctionStddevPopSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::stddevPop>>;
template <typename T> using AggregateFunctionStddevSampSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::stddevSamp>>;
template <typename T> using AggregateFunctionVarPopSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::varPop, 2>>;
template <typename T> using AggregateFunctionVarSampSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::varSamp, 2>>;
template <typename T> using AggregateFunctionStddevPopSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::stddevPop, 2>>;
template <typename T> using AggregateFunctionStddevSampSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::stddevSamp, 2>>;
template <typename T> using AggregateFunctionSkewPopSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::skewPop, 3>>;
template <typename T> using AggregateFunctionSkewSampSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::skewSamp, 3>>;
template <typename T> using AggregateFunctionKurtPopSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::kurtPop, 4>>;
template <typename T> using AggregateFunctionKurtSampSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::kurtSamp, 4>>;
template <typename T1, typename T2> using AggregateFunctionCovarPopSimple = AggregateFunctionVarianceSimple<StatFuncTwoArg<T1, T2, StatisticsFunctionKind::covarPop>>;
template <typename T1, typename T2> using AggregateFunctionCovarSampSimple = AggregateFunctionVarianceSimple<StatFuncTwoArg<T1, T2, StatisticsFunctionKind::covarSamp>>;
template <typename T1, typename T2> using AggregateFunctionCorrSimple = AggregateFunctionVarianceSimple<StatFuncTwoArg<T1, T2, StatisticsFunctionKind::corr>>;

View File

@ -48,6 +48,12 @@ public:
/// Get the result type.
virtual DataTypePtr getReturnType() const = 0;
/// Get type which will be used for prediction result in case if function is an ML method.
virtual DataTypePtr getReturnTypeToPredict() const
{
throw Exception("Prediction is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
}
virtual ~IAggregateFunction() {}
/** Data manipulating functions. */

View File

@ -30,7 +30,7 @@ void registerAggregateFunctionsBitmap(AggregateFunctionFactory &);
void registerAggregateFunctionsMaxIntersections(AggregateFunctionFactory &);
void registerAggregateFunctionMLMethod(AggregateFunctionFactory &);
void registerAggregateFunctionEntropy(AggregateFunctionFactory &);
void registerAggregateFunctionLeastSqr(AggregateFunctionFactory &);
void registerAggregateFunctionSimpleLinearRegression(AggregateFunctionFactory &);
void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &);
void registerAggregateFunctionCombinatorArray(AggregateFunctionCombinatorFactory &);
@ -73,7 +73,7 @@ void registerAggregateFunctions()
registerAggregateFunctionTimeSeriesGroupSum(factory);
registerAggregateFunctionMLMethod(factory);
registerAggregateFunctionEntropy(factory);
registerAggregateFunctionLeastSqr(factory);
registerAggregateFunctionSimpleLinearRegression(factory);
}
{

View File

@ -35,25 +35,6 @@ void ColumnAggregateFunction::addArena(ArenaPtr arena_)
arenas.push_back(arena_);
}
/// This function is used in convertToValues() and predictValues()
/// and is written here to avoid repetitions
bool ColumnAggregateFunction::tryFinalizeAggregateFunction(MutableColumnPtr *res_) const
{
if (const AggregateFunctionState *function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
{
auto res = createView();
res->set(function_state->getNestedFunction());
res->data.assign(data.begin(), data.end());
*res_ = std::move(res);
return true;
}
MutableColumnPtr res = func->getReturnType()->createColumn();
res->reserve(data.size());
*res_ = std::move(res);
return false;
}
MutableColumnPtr ColumnAggregateFunction::convertToValues() const
{
/** If the aggregate function returns an unfinalized/unfinished state,
@ -86,17 +67,17 @@ MutableColumnPtr ColumnAggregateFunction::convertToValues() const
* AggregateFunction(quantileTiming(0.5), UInt64)
* into UInt16 - already finished result of `quantileTiming`.
*/
/** Convertion function is used in convertToValues and predictValues
* in the similar part of both functions
*/
MutableColumnPtr res;
if (tryFinalizeAggregateFunction(&res))
if (const AggregateFunctionState *function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
{
auto res = createView();
res->set(function_state->getNestedFunction());
res->data.assign(data.begin(), data.end());
return res;
}
MutableColumnPtr res = func->getReturnType()->createColumn();
res->reserve(data.size());
for (auto val : data)
func->insertResultInto(val, *res);
@ -105,8 +86,8 @@ MutableColumnPtr ColumnAggregateFunction::convertToValues() const
MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const ColumnNumbers & arguments, const Context & context) const
{
MutableColumnPtr res;
tryFinalizeAggregateFunction(&res);
MutableColumnPtr res = func->getReturnTypeToPredict()->createColumn();
res->reserve(data.size());
auto ML_function = func.get();
if (ML_function)

View File

@ -427,6 +427,8 @@ namespace ErrorCodes
extern const int BAD_TTL_EXPRESSION = 450;
extern const int BAD_TTL_FILE = 451;
extern const int SETTING_CONSTRAINT_VIOLATION = 452;
extern const int MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES = 453;
extern const int OPENSSL_ERROR = 454;
extern const int KEEPER_EXCEPTION = 999;
extern const int POCO_EXCEPTION = 1000;

View File

@ -0,0 +1,18 @@
#include "OpenSSLHelpers.h"
#include <ext/scope_guard.h>
#include <openssl/err.h>
namespace DB
{
String getOpenSSLErrors()
{
BIO * mem = BIO_new(BIO_s_mem());
SCOPE_EXIT(BIO_free(mem));
ERR_print_errors(mem);
char * buf = nullptr;
long size = BIO_get_mem_data(mem, &buf);
return String(buf, size);
}
}

View File

@ -0,0 +1,12 @@
#pragma once
#include <Core/Types.h>
namespace DB
{
/// Returns concatenation of error strings for all errors that OpenSSL has recorded, emptying the error queue.
String getOpenSSLErrors();
}

View File

@ -56,6 +56,8 @@
#define DBMS_MIN_REVISION_WITH_LOW_CARDINALITY_TYPE 54405
#define DBMS_MIN_REVISION_WITH_CLIENT_WRITE_INFO 54421
/// Version of ClickHouse TCP protocol. Set to git tag with latest protocol change.
#define DBMS_TCP_PROTOCOL_VERSION 54226

View File

@ -0,0 +1,94 @@
#include <IO/WriteBuffer.h>
#include <IO/ReadBufferFromString.h>
#include <IO/WriteBufferFromString.h>
#include <common/logger_useful.h>
#include <random>
#include <sstream>
#include "MySQLProtocol.h"
namespace DB
{
namespace MySQLProtocol
{
void PacketSender::resetSequenceId()
{
sequence_id = 0;
}
String PacketSender::packetToText(String payload)
{
String result;
for (auto c : payload)
{
result += ' ';
result += std::to_string(static_cast<unsigned char>(c));
}
return result;
}
uint64_t readLengthEncodedNumber(std::istringstream & ss)
{
char c;
uint64_t buf = 0;
ss.get(c);
auto cc = static_cast<uint8_t>(c);
if (cc < 0xfc)
{
return cc;
}
else if (cc < 0xfd)
{
ss.read(reinterpret_cast<char *>(&buf), 2);
}
else if (cc < 0xfe)
{
ss.read(reinterpret_cast<char *>(&buf), 3);
}
else
{
ss.read(reinterpret_cast<char *>(&buf), 8);
}
return buf;
}
std::string writeLengthEncodedNumber(uint64_t x)
{
std::string result;
if (x < 251)
{
result.append(1, static_cast<char>(x));
}
else if (x < (1 << 16))
{
result.append(1, 0xfc);
result.append(reinterpret_cast<char *>(&x), 2);
}
else if (x < (1 << 24))
{
result.append(1, 0xfd);
result.append(reinterpret_cast<char *>(&x), 3);
}
else
{
result.append(1, 0xfe);
result.append(reinterpret_cast<char *>(&x), 8);
}
return result;
}
void writeLengthEncodedString(std::string & payload, const std::string & s)
{
payload.append(writeLengthEncodedNumber(s.length()));
payload.append(s);
}
void writeNulTerminatedString(std::string & payload, const std::string & s)
{
payload.append(s);
payload.append(1, 0);
}
}
}

View File

@ -0,0 +1,649 @@
#pragma once
#include <Core/Types.h>
#include <IO/copyData.h>
#include <IO/ReadBuffer.h>
#include <IO/ReadBufferFromPocoSocket.h>
#include <IO/WriteBuffer.h>
#include <IO/WriteBufferFromPocoSocket.h>
#include <IO/WriteBufferFromString.h>
#include <Poco/Net/StreamSocket.h>
#include <Poco/RandomStream.h>
#include <random>
#include <sstream>
/// Implementation of MySQL wire protocol
namespace DB
{
namespace ErrorCodes
{
extern const int UNKNOWN_PACKET_FROM_CLIENT;
}
namespace MySQLProtocol
{
const size_t MAX_PACKET_LENGTH = (1 << 24) - 1; // 16 mb
const size_t SCRAMBLE_LENGTH = 20;
const size_t AUTH_PLUGIN_DATA_PART_1_LENGTH = 8;
const size_t MYSQL_ERRMSG_SIZE = 512;
const size_t PACKET_HEADER_SIZE = 4;
const size_t SSL_REQUEST_PAYLOAD_SIZE = 32;
namespace Authentication
{
const String SHA256 = "sha256_password"; /// Caching SHA2 plugin is not used because it would be possible to authenticate knowing hash from users.xml.
}
enum CharacterSet
{
utf8_general_ci = 33,
binary = 63
};
enum StatusFlags
{
SERVER_SESSION_STATE_CHANGED = 0x4000
};
enum Capability
{
CLIENT_CONNECT_WITH_DB = 0x00000008,
CLIENT_PROTOCOL_41 = 0x00000200,
CLIENT_SSL = 0x00000800,
CLIENT_TRANSACTIONS = 0x00002000, // TODO
CLIENT_SESSION_TRACK = 0x00800000, // TODO
CLIENT_SECURE_CONNECTION = 0x00008000,
CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA = 0x00200000,
CLIENT_PLUGIN_AUTH = 0x00080000,
CLIENT_DEPRECATE_EOF = 0x01000000,
};
enum Command
{
COM_SLEEP = 0x0,
COM_QUIT = 0x1,
COM_INIT_DB = 0x2,
COM_QUERY = 0x3,
COM_FIELD_LIST = 0x4,
COM_CREATE_DB = 0x5,
COM_DROP_DB = 0x6,
COM_REFRESH = 0x7,
COM_SHUTDOWN = 0x8,
COM_STATISTICS = 0x9,
COM_PROCESS_INFO = 0xa,
COM_CONNECT = 0xb,
COM_PROCESS_KILL = 0xc,
COM_DEBUG = 0xd,
COM_PING = 0xe,
COM_TIME = 0xf,
COM_DELAYED_INSERT = 0x10,
COM_CHANGE_USER = 0x11,
COM_RESET_CONNECTION = 0x1f,
COM_DAEMON = 0x1d
};
enum ColumnType
{
MYSQL_TYPE_DECIMAL = 0x00,
MYSQL_TYPE_TINY = 0x01,
MYSQL_TYPE_SHORT = 0x02,
MYSQL_TYPE_LONG = 0x03,
MYSQL_TYPE_FLOAT = 0x04,
MYSQL_TYPE_DOUBLE = 0x05,
MYSQL_TYPE_NULL = 0x06,
MYSQL_TYPE_TIMESTAMP = 0x07,
MYSQL_TYPE_LONGLONG = 0x08,
MYSQL_TYPE_INT24 = 0x09,
MYSQL_TYPE_DATE = 0x0a,
MYSQL_TYPE_TIME = 0x0b,
MYSQL_TYPE_DATETIME = 0x0c,
MYSQL_TYPE_YEAR = 0x0d,
MYSQL_TYPE_VARCHAR = 0x0f,
MYSQL_TYPE_BIT = 0x10,
MYSQL_TYPE_NEWDECIMAL = 0xf6,
MYSQL_TYPE_ENUM = 0xf7,
MYSQL_TYPE_SET = 0xf8,
MYSQL_TYPE_TINY_BLOB = 0xf9,
MYSQL_TYPE_MEDIUM_BLOB = 0xfa,
MYSQL_TYPE_LONG_BLOB = 0xfb,
MYSQL_TYPE_BLOB = 0xfc,
MYSQL_TYPE_VAR_STRING = 0xfd,
MYSQL_TYPE_STRING = 0xfe,
MYSQL_TYPE_GEOMETRY = 0xff
};
class ProtocolError : public DB::Exception
{
public:
using Exception::Exception;
};
class WritePacket
{
public:
virtual String getPayload() const = 0;
virtual ~WritePacket() = default;
};
class ReadPacket
{
public:
ReadPacket() = default;
ReadPacket(const ReadPacket &) = default;
virtual void readPayload(String payload) = 0;
virtual ~ReadPacket() = default;
};
/* Writes and reads packets, keeping sequence-id.
* Throws ProtocolError, if packet with incorrect sequence-id was received.
*/
class PacketSender
{
public:
size_t & sequence_id;
ReadBuffer * in;
WriteBuffer * out;
size_t max_packet_size = MAX_PACKET_LENGTH;
/// For reading and writing.
PacketSender(ReadBuffer & in, WriteBuffer & out, size_t & sequence_id)
: sequence_id(sequence_id)
, in(&in)
, out(&out)
{
}
/// For writing.
PacketSender(WriteBuffer & out, size_t & sequence_id)
: sequence_id(sequence_id)
, in(nullptr)
, out(&out)
{
}
String receivePacketPayload()
{
WriteBufferFromOwnString buf;
size_t payload_length = 0;
size_t packet_sequence_id = 0;
// packets which are larger than or equal to 16MB are splitted
do
{
in->readStrict(reinterpret_cast<char *>(&payload_length), 3);
if (payload_length > max_packet_size)
{
std::ostringstream tmp;
tmp << "Received packet with payload larger than max_packet_size: " << payload_length;
throw ProtocolError(tmp.str(), ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT);
}
in->readStrict(reinterpret_cast<char *>(&packet_sequence_id), 1);
if (packet_sequence_id != sequence_id)
{
std::ostringstream tmp;
tmp << "Received packet with wrong sequence-id: " << packet_sequence_id << ". Expected: " << sequence_id << '.';
throw ProtocolError(tmp.str(), ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT);
}
sequence_id++;
copyData(*in, static_cast<WriteBuffer &>(buf), payload_length);
} while (payload_length == max_packet_size);
return std::move(buf.str());
}
void receivePacket(ReadPacket & packet)
{
packet.readPayload(receivePacketPayload());
}
template<class T>
void sendPacket(const T & packet, bool flush = false)
{
static_assert(std::is_base_of<WritePacket, T>());
String payload = packet.getPayload();
size_t pos = 0;
do
{
size_t payload_length = std::min(payload.length() - pos, max_packet_size);
out->write(reinterpret_cast<const char *>(&payload_length), 3);
out->write(reinterpret_cast<const char *>(&sequence_id), 1);
out->write(payload.data() + pos, payload_length);
pos += payload_length;
sequence_id++;
} while (pos < payload.length());
if (flush)
out->next();
}
/// Sets sequence-id to 0. Must be called before each command phase.
void resetSequenceId();
private:
/// Converts packet to text. Is used for debug output.
static String packetToText(String payload);
};
uint64_t readLengthEncodedNumber(std::istringstream & ss);
String writeLengthEncodedNumber(uint64_t x);
void writeLengthEncodedString(String & payload, const String & s);
void writeNulTerminatedString(String & payload, const String & s);
class Handshake : public WritePacket
{
int protocol_version = 0xa;
String server_version;
uint32_t connection_id;
uint32_t capability_flags;
uint8_t character_set;
uint32_t status_flags;
String auth_plugin_data;
public:
explicit Handshake(uint32_t capability_flags, uint32_t connection_id, String server_version, String auth_plugin_data)
: protocol_version(0xa)
, server_version(std::move(server_version))
, connection_id(connection_id)
, capability_flags(capability_flags)
, character_set(CharacterSet::utf8_general_ci)
, status_flags(0)
, auth_plugin_data(auth_plugin_data)
{
}
String getPayload() const override
{
String result;
result.append(1, protocol_version);
writeNulTerminatedString(result, server_version);
result.append(reinterpret_cast<const char *>(&connection_id), 4);
writeNulTerminatedString(result, auth_plugin_data.substr(0, AUTH_PLUGIN_DATA_PART_1_LENGTH));
result.append(reinterpret_cast<const char *>(&capability_flags), 2);
result.append(reinterpret_cast<const char *>(&character_set), 1);
result.append(reinterpret_cast<const char *>(&status_flags), 2);
result.append((reinterpret_cast<const char *>(&capability_flags)) + 2, 2);
result.append(1, auth_plugin_data.size());
result.append(10, 0x0);
result.append(auth_plugin_data.substr(AUTH_PLUGIN_DATA_PART_1_LENGTH, auth_plugin_data.size() - AUTH_PLUGIN_DATA_PART_1_LENGTH));
result.append(Authentication::SHA256);
result.append(1, 0x0);
return result;
}
};
class SSLRequest : public ReadPacket
{
public:
uint32_t capability_flags;
uint32_t max_packet_size;
uint8_t character_set;
void readPayload(String s) override
{
std::istringstream ss(s);
ss.readsome(reinterpret_cast<char *>(&capability_flags), 4);
ss.readsome(reinterpret_cast<char *>(&max_packet_size), 4);
ss.readsome(reinterpret_cast<char *>(&character_set), 1);
}
};
class HandshakeResponse : public ReadPacket
{
public:
uint32_t capability_flags;
uint32_t max_packet_size;
uint8_t character_set;
String username;
String auth_response;
String database;
String auth_plugin_name;
HandshakeResponse() = default;
HandshakeResponse(const HandshakeResponse &) = default;
void readPayload(String s) override
{
std::istringstream ss(s);
ss.readsome(reinterpret_cast<char *>(&capability_flags), 4);
ss.readsome(reinterpret_cast<char *>(&max_packet_size), 4);
ss.readsome(reinterpret_cast<char *>(&character_set), 1);
ss.ignore(23);
std::getline(ss, username, static_cast<char>(0x0));
if (capability_flags & CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA)
{
auto len = readLengthEncodedNumber(ss);
auth_response.resize(len);
ss.read(auth_response.data(), static_cast<std::streamsize>(len));
}
else if (capability_flags & CLIENT_SECURE_CONNECTION)
{
uint8_t len;
ss.read(reinterpret_cast<char *>(&len), 1);
auth_response.resize(len);
ss.read(auth_response.data(), len);
}
else
{
std::getline(ss, auth_response, static_cast<char>(0x0));
}
if (capability_flags & CLIENT_CONNECT_WITH_DB)
{
std::getline(ss, database, static_cast<char>(0x0));
}
if (capability_flags & CLIENT_PLUGIN_AUTH)
{
std::getline(ss, auth_plugin_name, static_cast<char>(0x0));
}
}
};
class AuthSwitchRequest : public WritePacket
{
String plugin_name;
String auth_plugin_data;
public:
AuthSwitchRequest(String plugin_name, String auth_plugin_data)
: plugin_name(std::move(plugin_name)), auth_plugin_data(std::move(auth_plugin_data))
{
}
String getPayload() const override
{
String result;
result.append(1, 0xfe);
writeNulTerminatedString(result, plugin_name);
result.append(auth_plugin_data);
return result;
}
};
class AuthSwitchResponse : public ReadPacket
{
public:
String value;
void readPayload(String s) override
{
value = std::move(s);
}
};
class AuthMoreData : public WritePacket
{
String data;
public:
AuthMoreData(String data): data(std::move(data)) {}
String getPayload() const override
{
String result;
result.append(1, 0x01);
result.append(data);
return result;
}
};
/// Packet with a single null-terminated string. Is used for clear text authentication.
class NullTerminatedString : public ReadPacket
{
public:
String value;
void readPayload(String s) override
{
if (s.length() == 0 || s.back() != 0)
{
throw ProtocolError("String is not null terminated.", ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT);
}
value = s;
value.pop_back();
}
};
class OK_Packet : public WritePacket
{
uint8_t header;
uint32_t capabilities;
uint64_t affected_rows;
int16_t warnings = 0;
uint32_t status_flags;
String session_state_changes;
String info;
public:
OK_Packet(uint8_t header,
uint32_t capabilities,
uint64_t affected_rows,
uint32_t status_flags,
int16_t warnings,
String session_state_changes = "",
String info = "")
: header(header)
, capabilities(capabilities)
, affected_rows(affected_rows)
, warnings(warnings)
, status_flags(status_flags)
, session_state_changes(std::move(session_state_changes))
, info(info)
{
}
String getPayload() const override
{
String result;
result.append(1, header);
result.append(writeLengthEncodedNumber(affected_rows));
result.append(writeLengthEncodedNumber(0)); /// last insert-id
if (capabilities & CLIENT_PROTOCOL_41)
{
result.append(reinterpret_cast<const char *>(&status_flags), 2);
result.append(reinterpret_cast<const char *>(&warnings), 2);
}
else if (capabilities & CLIENT_TRANSACTIONS)
{
result.append(reinterpret_cast<const char *>(&status_flags), 2);
}
if (capabilities & CLIENT_SESSION_TRACK)
{
result.append(writeLengthEncodedNumber(info.length()));
result.append(info);
if (status_flags & SERVER_SESSION_STATE_CHANGED)
{
result.append(writeLengthEncodedNumber(session_state_changes.length()));
result.append(session_state_changes);
}
}
else
{
result.append(info);
}
return result;
}
};
class EOF_Packet : public WritePacket
{
int warnings;
int status_flags;
public:
EOF_Packet(int warnings, int status_flags) : warnings(warnings), status_flags(status_flags)
{}
String getPayload() const override
{
String result;
result.append(1, 0xfe); // EOF header
result.append(reinterpret_cast<const char *>(&warnings), 2);
result.append(reinterpret_cast<const char *>(&status_flags), 2);
return result;
}
};
class ERR_Packet : public WritePacket
{
int error_code;
String sql_state;
String error_message;
public:
ERR_Packet(int error_code, String sql_state, String error_message)
: error_code(error_code), sql_state(std::move(sql_state)), error_message(std::move(error_message))
{
}
String getPayload() const override
{
String result;
result.append(1, 0xff);
result.append(reinterpret_cast<const char *>(&error_code), 2);
result.append("#", 1);
result.append(sql_state.data(), sql_state.length());
result.append(error_message.data(), std::min(error_message.length(), MYSQL_ERRMSG_SIZE));
return result;
}
};
class ColumnDefinition : public WritePacket
{
String schema;
String table;
String org_table;
String name;
String org_name;
size_t next_length = 0x0c;
uint16_t character_set;
uint32_t column_length;
ColumnType column_type;
uint16_t flags;
uint8_t decimals = 0x00;
public:
ColumnDefinition(
String schema,
String table,
String org_table,
String name,
String org_name,
uint16_t character_set,
uint32_t column_length,
ColumnType column_type,
uint16_t flags,
uint8_t decimals)
: schema(std::move(schema)), table(std::move(table)), org_table(std::move(org_table)), name(std::move(name)),
org_name(std::move(org_name)), character_set(character_set), column_length(column_length), column_type(column_type), flags(flags),
decimals(decimals)
{
}
/// Should be used when column metadata (original name, table, original table, database) is unknown.
ColumnDefinition(
String name,
uint16_t character_set,
uint32_t column_length,
ColumnType column_type,
uint16_t flags,
uint8_t decimals)
: ColumnDefinition("", "", "", std::move(name), "", character_set, column_length, column_type, flags, decimals)
{
}
String getPayload() const override
{
String result;
writeLengthEncodedString(result, "def"); /// always "def"
writeLengthEncodedString(result, ""); /// schema
writeLengthEncodedString(result, ""); /// table
writeLengthEncodedString(result, ""); /// org_table
writeLengthEncodedString(result, name);
writeLengthEncodedString(result, ""); /// org_name
result.append(writeLengthEncodedNumber(next_length));
result.append(reinterpret_cast<const char *>(&character_set), 2);
result.append(reinterpret_cast<const char *>(&column_length), 4);
result.append(reinterpret_cast<const char *>(&column_type), 1);
result.append(reinterpret_cast<const char *>(&flags), 2);
result.append(reinterpret_cast<const char *>(&decimals), 2);
result.append(2, 0x0);
return result;
}
};
class ComFieldList : public ReadPacket
{
public:
String table, field_wildcard;
void readPayload(String payload)
{
std::istringstream ss(payload);
ss.ignore(1); // command byte
std::getline(ss, table, static_cast<char>(0x0));
field_wildcard = payload.substr(table.length() + 2); // rest of the packet
}
};
class LengthEncodedNumber : public WritePacket
{
uint64_t value;
public:
LengthEncodedNumber(uint64_t value): value(value)
{
}
String getPayload() const override
{
return writeLengthEncodedNumber(value);
}
};
class ResultsetRow : public WritePacket
{
std::vector<String> columns;
public:
ResultsetRow()
{
}
void appendColumn(String value)
{
columns.emplace_back(std::move(value));
}
String getPayload() const override
{
String result;
for (const String & column : columns)
{
writeLengthEncodedString(result, column);
}
return result;
}
};
}
}

View File

@ -170,6 +170,46 @@ template <> struct NativeType<Decimal32> { using Type = Int32; };
template <> struct NativeType<Decimal64> { using Type = Int64; };
template <> struct NativeType<Decimal128> { using Type = Int128; };
inline const char * getTypeName(TypeIndex idx)
{
switch (idx)
{
case TypeIndex::Nothing: return "Nothing";
case TypeIndex::UInt8: return TypeName<UInt8>::get();
case TypeIndex::UInt16: return TypeName<UInt16>::get();
case TypeIndex::UInt32: return TypeName<UInt32>::get();
case TypeIndex::UInt64: return TypeName<UInt64>::get();
case TypeIndex::UInt128: return "UInt128";
case TypeIndex::Int8: return TypeName<Int8>::get();
case TypeIndex::Int16: return TypeName<Int16>::get();
case TypeIndex::Int32: return TypeName<Int32>::get();
case TypeIndex::Int64: return TypeName<Int64>::get();
case TypeIndex::Int128: return TypeName<Int128>::get();
case TypeIndex::Float32: return TypeName<Float32>::get();
case TypeIndex::Float64: return TypeName<Float64>::get();
case TypeIndex::Date: return "Date";
case TypeIndex::DateTime: return "DateTime";
case TypeIndex::String: return TypeName<String>::get();
case TypeIndex::FixedString: return "FixedString";
case TypeIndex::Enum8: return "Enum8";
case TypeIndex::Enum16: return "Enum16";
case TypeIndex::Decimal32: return TypeName<Decimal32>::get();
case TypeIndex::Decimal64: return TypeName<Decimal64>::get();
case TypeIndex::Decimal128: return TypeName<Decimal128>::get();
case TypeIndex::UUID: return "UUID";
case TypeIndex::Array: return "Array";
case TypeIndex::Tuple: return "Tuple";
case TypeIndex::Set: return "Set";
case TypeIndex::Interval: return "Interval";
case TypeIndex::Nullable: return "Nullable";
case TypeIndex::Function: return "Function";
case TypeIndex::AggregateFunction: return "AggregateFunction";
case TypeIndex::LowCardinality: return "LowCardinality";
}
__builtin_unreachable();
}
}
/// Specialization of `std::hash` for the Decimal<T> types.

View File

@ -10,6 +10,7 @@
#include <Functions/IFunction.h>
#include <IO/WriteBufferFromOStream.h>
#include <Interpreters/ExpressionAnalyzer.h>
#include <Interpreters/ExpressionActions.h>
#include <Parsers/IAST.h>
#include <Storages/IStorage.h>
#include <Common/COW.h>
@ -70,7 +71,7 @@ std::ostream & operator<<(std::ostream & stream, const Block & what)
std::ostream & operator<<(std::ostream & stream, const ColumnWithTypeAndName & what)
{
stream << "ColumnWithTypeAndName(name = " << what.name << ", type = " << what.type << ", column = ";
stream << "ColumnWithTypeAndName(name = " << what.name << ", type = " << *what.type << ", column = ";
return dumpValue(stream, what.column) << ")";
}
@ -109,4 +110,56 @@ std::ostream & operator<<(std::ostream & stream, const IAST & what)
return stream;
}
std::ostream & operator<<(std::ostream & stream, const ExpressionAction & what)
{
stream << "ExpressionAction(" << what.toString() << ")";
return stream;
}
std::ostream & operator<<(std::ostream & stream, const ExpressionActions & what)
{
stream << "ExpressionActions(" << what.dumpActions() << ")";
return stream;
}
std::ostream & operator<<(std::ostream & stream, const SyntaxAnalyzerResult & what)
{
stream << "SyntaxAnalyzerResult{";
stream << "storage=" << what.storage << "; ";
if (!what.source_columns.empty())
{
stream << "source_columns=";
dumpValue(stream, what.source_columns);
stream << "; ";
}
if (!what.aliases.empty())
{
stream << "aliases=";
dumpValue(stream, what.aliases);
stream << "; ";
}
if (!what.array_join_result_to_source.empty())
{
stream << "array_join_result_to_source=";
dumpValue(stream, what.array_join_result_to_source);
stream << "; ";
}
if (!what.array_join_alias_to_name.empty())
{
stream << "array_join_alias_to_name=";
dumpValue(stream, what.array_join_alias_to_name);
stream << "; ";
}
if (!what.array_join_name_to_alias.empty())
{
stream << "array_join_name_to_alias=";
dumpValue(stream, what.array_join_name_to_alias);
stream << "; ";
}
stream << "rewrite_subqueries=" << what.rewrite_subqueries << "; ";
stream << "}";
return stream;
}
}

View File

@ -41,6 +41,14 @@ std::ostream & operator<<(std::ostream & stream, const IAST & what);
std::ostream & operator<<(std::ostream & stream, const Connection::Packet & what);
struct ExpressionAction;
std::ostream & operator<<(std::ostream & stream, const ExpressionAction & what);
class ExpressionActions;
std::ostream & operator<<(std::ostream & stream, const ExpressionActions & what);
struct SyntaxAnalyzerResult;
std::ostream & operator<<(std::ostream & stream, const SyntaxAnalyzerResult & what);
}
/// some operator<< should be declared before operator<<(... std::shared_ptr<>)

View File

@ -19,8 +19,8 @@ void CountingBlockOutputStream::write(const Block & block)
Progress local_progress(block.rows(), block.bytes(), 0);
progress.incrementPiecewiseAtomically(local_progress);
ProfileEvents::increment(ProfileEvents::InsertedRows, local_progress.rows);
ProfileEvents::increment(ProfileEvents::InsertedBytes, local_progress.bytes);
ProfileEvents::increment(ProfileEvents::InsertedRows, local_progress.read_rows);
ProfileEvents::increment(ProfileEvents::InsertedBytes, local_progress.read_bytes);
if (process_elem)
process_elem->updateProgressOut(local_progress);

View File

@ -281,7 +281,7 @@ void IBlockInputStream::progressImpl(const Progress & value)
/// The total amount of data processed or intended for processing in all leaf sources, possibly on remote servers.
ProgressValues progress = process_list_elem->getProgressIn();
size_t total_rows_estimate = std::max(progress.rows, progress.total_rows);
size_t total_rows_estimate = std::max(progress.read_rows, progress.total_rows_to_read);
/** Check the restrictions on the amount of data to read, the speed of the query, the quota on the amount of data to read.
* NOTE: Maybe it makes sense to have them checked directly in ProcessList?
@ -289,7 +289,7 @@ void IBlockInputStream::progressImpl(const Progress & value)
if (limits.mode == LIMITS_TOTAL
&& ((limits.size_limits.max_rows && total_rows_estimate > limits.size_limits.max_rows)
|| (limits.size_limits.max_bytes && progress.bytes > limits.size_limits.max_bytes)))
|| (limits.size_limits.max_bytes && progress.read_bytes > limits.size_limits.max_bytes)))
{
switch (limits.size_limits.overflow_mode)
{
@ -300,7 +300,7 @@ void IBlockInputStream::progressImpl(const Progress & value)
+ " rows read (or to read), maximum: " + toString(limits.size_limits.max_rows),
ErrorCodes::TOO_MANY_ROWS);
else
throw Exception("Limit for (uncompressed) bytes to read exceeded: " + toString(progress.bytes)
throw Exception("Limit for (uncompressed) bytes to read exceeded: " + toString(progress.read_bytes)
+ " bytes read, maximum: " + toString(limits.size_limits.max_bytes),
ErrorCodes::TOO_MANY_BYTES);
}
@ -308,8 +308,8 @@ void IBlockInputStream::progressImpl(const Progress & value)
case OverflowMode::BREAK:
{
/// For `break`, we will stop only if so many rows were actually read, and not just supposed to be read.
if ((limits.size_limits.max_rows && progress.rows > limits.size_limits.max_rows)
|| (limits.size_limits.max_bytes && progress.bytes > limits.size_limits.max_bytes))
if ((limits.size_limits.max_rows && progress.read_rows > limits.size_limits.max_rows)
|| (limits.size_limits.max_bytes && progress.read_bytes > limits.size_limits.max_bytes))
{
cancel(false);
}
@ -322,7 +322,7 @@ void IBlockInputStream::progressImpl(const Progress & value)
}
}
size_t total_rows = progress.total_rows;
size_t total_rows = progress.total_rows_to_read;
constexpr UInt64 profile_events_update_period_microseconds = 10 * 1000; // 10 milliseconds
UInt64 total_elapsed_microseconds = info.total_stopwatch.elapsedMicroseconds();
@ -344,20 +344,20 @@ void IBlockInputStream::progressImpl(const Progress & value)
if (elapsed_seconds > 0)
{
if (limits.min_execution_speed && progress.rows / elapsed_seconds < limits.min_execution_speed)
throw Exception("Query is executing too slow: " + toString(progress.rows / elapsed_seconds)
if (limits.min_execution_speed && progress.read_rows / elapsed_seconds < limits.min_execution_speed)
throw Exception("Query is executing too slow: " + toString(progress.read_rows / elapsed_seconds)
+ " rows/sec., minimum: " + toString(limits.min_execution_speed),
ErrorCodes::TOO_SLOW);
if (limits.min_execution_speed_bytes && progress.bytes / elapsed_seconds < limits.min_execution_speed_bytes)
throw Exception("Query is executing too slow: " + toString(progress.bytes / elapsed_seconds)
if (limits.min_execution_speed_bytes && progress.read_bytes / elapsed_seconds < limits.min_execution_speed_bytes)
throw Exception("Query is executing too slow: " + toString(progress.read_bytes / elapsed_seconds)
+ " bytes/sec., minimum: " + toString(limits.min_execution_speed_bytes),
ErrorCodes::TOO_SLOW);
/// If the predicted execution time is longer than `max_execution_time`.
if (limits.max_execution_time != 0 && total_rows)
{
double estimated_execution_time_seconds = elapsed_seconds * (static_cast<double>(total_rows) / progress.rows);
double estimated_execution_time_seconds = elapsed_seconds * (static_cast<double>(total_rows) / progress.read_rows);
if (estimated_execution_time_seconds > limits.max_execution_time.totalSeconds())
throw Exception("Estimated query execution time (" + toString(estimated_execution_time_seconds) + " seconds)"
@ -366,17 +366,17 @@ void IBlockInputStream::progressImpl(const Progress & value)
ErrorCodes::TOO_SLOW);
}
if (limits.max_execution_speed && progress.rows / elapsed_seconds >= limits.max_execution_speed)
limitProgressingSpeed(progress.rows, limits.max_execution_speed, total_elapsed_microseconds);
if (limits.max_execution_speed && progress.read_rows / elapsed_seconds >= limits.max_execution_speed)
limitProgressingSpeed(progress.read_rows, limits.max_execution_speed, total_elapsed_microseconds);
if (limits.max_execution_speed_bytes && progress.bytes / elapsed_seconds >= limits.max_execution_speed_bytes)
limitProgressingSpeed(progress.bytes, limits.max_execution_speed_bytes, total_elapsed_microseconds);
if (limits.max_execution_speed_bytes && progress.read_bytes / elapsed_seconds >= limits.max_execution_speed_bytes)
limitProgressingSpeed(progress.read_bytes, limits.max_execution_speed_bytes, total_elapsed_microseconds);
}
}
if (quota != nullptr && limits.mode == LIMITS_TOTAL)
{
quota->checkAndAddReadRowsBytes(time(nullptr), value.rows, value.bytes);
quota->checkAndAddReadRowsBytes(time(nullptr), value.read_rows, value.read_bytes);
}
}
}

View File

@ -36,6 +36,7 @@ public:
bool canBeInsideNullable() const override { return false; }
DataTypePtr getReturnType() const { return function->getReturnType(); }
DataTypePtr getReturnTypeToPredict() const { return function->getReturnTypeToPredict(); }
DataTypes getArgumentsDataTypes() const { return argument_types; }
/// NOTE These two functions for serializing single values are incompatible with the functions below.

View File

@ -821,7 +821,7 @@ MutableColumnUniquePtr DataTypeLowCardinality::createColumnUniqueImpl(const IDat
return creator(static_cast<ColumnVector<UInt16> *>(nullptr));
if (typeid_cast<const DataTypeDateTime *>(type))
return creator(static_cast<ColumnVector<UInt32> *>(nullptr));
if (isNumber(type))
if (isColumnedAsNumber(type))
{
MutableColumnUniquePtr column;
TypeListNumbers::forEach(CreateColumnVector(column, *type, creator));

View File

@ -581,11 +581,18 @@ inline bool isFloat(const T & data_type)
return which.isFloat();
}
template <typename T>
inline bool isNativeNumber(const T & data_type)
{
WhichDataType which(data_type);
return which.isNativeInt() || which.isNativeUInt() || which.isFloat();
}
template <typename T>
inline bool isNumber(const T & data_type)
{
WhichDataType which(data_type);
return which.isInt() || which.isUInt() || which.isFloat();
return which.isInt() || which.isUInt() || which.isFloat() || which.isDecimal();
}
template <typename T>

View File

@ -86,7 +86,7 @@ void CacheDictionary::toParent(const PaddedPODArray<Key> & ids, PaddedPODArray<K
{
const auto null_value = std::get<UInt64>(hierarchical_attribute->null_values);
getItemsNumber<UInt64>(*hierarchical_attribute, ids, out, [&](const size_t) { return null_value; });
getItemsNumberImpl<UInt64, UInt64>(*hierarchical_attribute, ids, out, [&](const size_t) { return null_value; });
}
@ -207,9 +207,7 @@ void CacheDictionary::isInConstantVector(const Key child_id, const PaddedPODArra
void CacheDictionary::getString(const std::string & attribute_name, const PaddedPODArray<Key> & ids, ColumnString * out) const
{
auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
const auto null_value = StringRef{std::get<String>(attribute.null_values)};
@ -220,9 +218,7 @@ void CacheDictionary::getString(
const std::string & attribute_name, const PaddedPODArray<Key> & ids, const ColumnString * const def, ColumnString * const out) const
{
auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
getItemsString(attribute, ids, out, [&](const size_t row) { return def->getDataAt(row); });
}
@ -231,9 +227,7 @@ void CacheDictionary::getString(
const std::string & attribute_name, const PaddedPODArray<Key> & ids, const String & def, ColumnString * const out) const
{
auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
getItemsString(attribute, ids, out, [&](const size_t) { return StringRef{def}; });
}

View File

@ -221,11 +221,6 @@ private:
Attribute createAttributeWithType(const AttributeUnderlyingType type, const Field & null_value);
template <typename OutputType, typename DefaultGetter>
void getItemsNumber(
Attribute & attribute, const PaddedPODArray<Key> & ids, ResultArrayType<OutputType> & out, DefaultGetter && get_default) const;
template <typename AttributeType, typename OutputType, typename DefaultGetter>
void getItemsNumberImpl(
Attribute & attribute, const PaddedPODArray<Key> & ids, ResultArrayType<OutputType> & out, DefaultGetter && get_default) const;

View File

@ -34,34 +34,6 @@ namespace ErrorCodes
extern const int TYPE_MISMATCH;
}
template <typename OutputType, typename DefaultGetter>
void CacheDictionary::getItemsNumber(
Attribute & attribute, const PaddedPODArray<Key> & ids, ResultArrayType<OutputType> & out, DefaultGetter && get_default) const
{
if (false)
{
}
#define DISPATCH(TYPE) \
else if (attribute.type == AttributeUnderlyingType::TYPE) \
getItemsNumberImpl<TYPE, OutputType>(attribute, ids, out, std::forward<DefaultGetter>(get_default));
DISPATCH(UInt8)
DISPATCH(UInt16)
DISPATCH(UInt32)
DISPATCH(UInt64)
DISPATCH(UInt128)
DISPATCH(Int8)
DISPATCH(Int16)
DISPATCH(Int32)
DISPATCH(Int64)
DISPATCH(Float32)
DISPATCH(Float64)
DISPATCH(Decimal32)
DISPATCH(Decimal64)
DISPATCH(Decimal128)
#undef DISPATCH
else throw Exception("Unexpected type of attribute: " + toString(attribute.type), ErrorCodes::LOGICAL_ERROR);
}
template <typename AttributeType, typename OutputType, typename DefaultGetter>
void CacheDictionary::getItemsNumberImpl(
Attribute & attribute, const PaddedPODArray<Key> & ids, ResultArrayType<OutputType> & out, DefaultGetter && get_default) const

View File

@ -12,13 +12,11 @@ using TYPE = @NAME@;
void CacheDictionary::get@NAME@(const std::string & attribute_name, const PaddedPODArray<Key> & ids, ResultArrayType<TYPE> & out) const
{
auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::@NAME@))
throw Exception {name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::@NAME@);
const auto null_value = std::get<TYPE>(attribute.null_values);
getItemsNumber<TYPE>(attribute, ids, out, [&](const size_t) { return null_value; });
getItemsNumberImpl<TYPE, TYPE>(attribute, ids, out, [&](const size_t) { return null_value; });
}
}

View File

@ -15,11 +15,9 @@ void CacheDictionary::get@NAME@(const std::string & attribute_name,
ResultArrayType<TYPE> & out) const
{
auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::@NAME@))
throw Exception {name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::@NAME@);
getItemsNumber<TYPE>(attribute, ids, out, [&](const size_t row) { return def[row]; });
getItemsNumberImpl<TYPE, TYPE>(attribute, ids, out, [&](const size_t row) { return def[row]; });
}
}

View File

@ -12,11 +12,9 @@ using TYPE = @NAME@;
void CacheDictionary::get@NAME@(const std::string & attribute_name, const PaddedPODArray<Key> & ids, const TYPE def, ResultArrayType<TYPE> & out) const
{
auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::@NAME@))
throw Exception {name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::@NAME@);
getItemsNumber<TYPE>(attribute, ids, out, [&](const size_t) { return def; });
getItemsNumberImpl<TYPE, TYPE>(attribute, ids, out, [&](const size_t) { return def; });
}
}

View File

@ -77,9 +77,7 @@ void ComplexKeyCacheDictionary::getString(
dict_struct.validateKeyTypes(key_types);
auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
const auto null_value = StringRef{std::get<String>(attribute.null_values)};
@ -96,9 +94,7 @@ void ComplexKeyCacheDictionary::getString(
dict_struct.validateKeyTypes(key_types);
auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
getItemsString(attribute, key_columns, out, [&](const size_t row) { return def->getDataAt(row); });
}
@ -113,9 +109,7 @@ void ComplexKeyCacheDictionary::getString(
dict_struct.validateKeyTypes(key_types);
auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
getItemsString(attribute, key_columns, out, [&](const size_t) { return StringRef{def}; });
}

View File

@ -256,34 +256,6 @@ private:
Attribute createAttributeWithType(const AttributeUnderlyingType type, const Field & null_value);
template <typename OutputType, typename DefaultGetter>
void
getItemsNumber(Attribute & attribute, const Columns & key_columns, PaddedPODArray<OutputType> & out, DefaultGetter && get_default) const
{
if (false)
{
}
#define DISPATCH(TYPE) \
else if (attribute.type == AttributeUnderlyingType::TYPE) \
getItemsNumberImpl<TYPE, OutputType>(attribute, key_columns, out, std::forward<DefaultGetter>(get_default));
DISPATCH(UInt8)
DISPATCH(UInt16)
DISPATCH(UInt32)
DISPATCH(UInt64)
DISPATCH(UInt128)
DISPATCH(Int8)
DISPATCH(Int16)
DISPATCH(Int32)
DISPATCH(Int64)
DISPATCH(Float32)
DISPATCH(Float64)
DISPATCH(Decimal32)
DISPATCH(Decimal64)
DISPATCH(Decimal128)
#undef DISPATCH
else throw Exception("Unexpected type of attribute: " + toString(attribute.type), ErrorCodes::LOGICAL_ERROR);
}
template <typename AttributeType, typename OutputType, typename DefaultGetter>
void getItemsNumberImpl(
Attribute & attribute, const Columns & key_columns, PaddedPODArray<OutputType> & out, DefaultGetter && get_default) const

View File

@ -13,12 +13,10 @@ void ComplexKeyCacheDictionary::get@NAME@(const std::string & attribute_name, co
dict_struct.validateKeyTypes(key_types);
auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::@NAME@))
throw Exception {name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::@NAME@);
const auto null_value = std::get<TYPE>(attribute.null_values);
getItemsNumber<TYPE>(attribute, key_columns, out, [&](const size_t) { return null_value; });
getItemsNumberImpl<TYPE, TYPE>(attribute, key_columns, out, [&](const size_t) { return null_value; });
}
}

View File

@ -18,10 +18,8 @@ void ComplexKeyCacheDictionary::get@NAME@(const std::string & attribute_name,
dict_struct.validateKeyTypes(key_types);
auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::@NAME@))
throw Exception {name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::@NAME@);
getItemsNumber<TYPE>(attribute, key_columns, out, [&](const size_t row) { return def[row]; });
getItemsNumberImpl<TYPE, TYPE>(attribute, key_columns, out, [&](const size_t row) { return def[row]; });
}
}

View File

@ -18,10 +18,8 @@ void ComplexKeyCacheDictionary::get@NAME@(const std::string & attribute_name,
dict_struct.validateKeyTypes(key_types);
auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::@NAME@))
throw Exception {name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::@NAME@);
getItemsNumber<TYPE>(attribute, key_columns, out, [&](const size_t) { return def; });
getItemsNumberImpl<TYPE, TYPE>(attribute, key_columns, out, [&](const size_t) { return def; });
}
}

View File

@ -50,13 +50,11 @@ ComplexKeyHashedDictionary::ComplexKeyHashedDictionary(
dict_struct.validateKeyTypes(key_types); \
\
const auto & attribute = getAttribute(attribute_name); \
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::TYPE)) \
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type), \
ErrorCodes::TYPE_MISMATCH}; \
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::TYPE); \
\
const auto null_value = std::get<TYPE>(attribute.null_values); \
\
getItemsNumber<TYPE>( \
getItemsImpl<TYPE, TYPE>( \
attribute, \
key_columns, \
[&](const size_t row, const auto value) { out[row] = value; }, \
@ -84,9 +82,7 @@ void ComplexKeyHashedDictionary::getString(
dict_struct.validateKeyTypes(key_types);
const auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
const auto & null_value = StringRef{std::get<String>(attribute.null_values)};
@ -108,11 +104,9 @@ void ComplexKeyHashedDictionary::getString(
dict_struct.validateKeyTypes(key_types); \
\
const auto & attribute = getAttribute(attribute_name); \
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::TYPE)) \
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type), \
ErrorCodes::TYPE_MISMATCH}; \
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::TYPE); \
\
getItemsNumber<TYPE>( \
getItemsImpl<TYPE, TYPE>( \
attribute, \
key_columns, \
[&](const size_t row, const auto value) { out[row] = value; }, \
@ -144,9 +138,7 @@ void ComplexKeyHashedDictionary::getString(
dict_struct.validateKeyTypes(key_types);
const auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
getItemsImpl<StringRef, StringRef>(
attribute,
@ -166,11 +158,9 @@ void ComplexKeyHashedDictionary::getString(
dict_struct.validateKeyTypes(key_types); \
\
const auto & attribute = getAttribute(attribute_name); \
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::TYPE)) \
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type), \
ErrorCodes::TYPE_MISMATCH}; \
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::TYPE); \
\
getItemsNumber<TYPE>( \
getItemsImpl<TYPE, TYPE>( \
attribute, key_columns, [&](const size_t row, const auto value) { out[row] = value; }, [&](const size_t) { return def; }); \
}
DECLARE(UInt8)
@ -199,9 +189,7 @@ void ComplexKeyHashedDictionary::getString(
dict_struct.validateKeyTypes(key_types);
const auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
getItemsImpl<StringRef, StringRef>(
attribute,
@ -566,34 +554,6 @@ ComplexKeyHashedDictionary::createAttributeWithType(const AttributeUnderlyingTyp
}
template <typename OutputType, typename ValueSetter, typename DefaultGetter>
void ComplexKeyHashedDictionary::getItemsNumber(
const Attribute & attribute, const Columns & key_columns, ValueSetter && set_value, DefaultGetter && get_default) const
{
if (false)
{
}
#define DISPATCH(TYPE) \
else if (attribute.type == AttributeUnderlyingType::TYPE) getItemsImpl<TYPE, OutputType>( \
attribute, key_columns, std::forward<ValueSetter>(set_value), std::forward<DefaultGetter>(get_default));
DISPATCH(UInt8)
DISPATCH(UInt16)
DISPATCH(UInt32)
DISPATCH(UInt64)
DISPATCH(UInt128)
DISPATCH(Int8)
DISPATCH(Int16)
DISPATCH(Int32)
DISPATCH(Int64)
DISPATCH(Float32)
DISPATCH(Float64)
DISPATCH(Decimal32)
DISPATCH(Decimal64)
DISPATCH(Decimal128)
#undef DISPATCH
else throw Exception("Unexpected type of attribute: " + toString(attribute.type), ErrorCodes::LOGICAL_ERROR);
}
template <typename AttributeType, typename OutputType, typename ValueSetter, typename DefaultGetter>
void ComplexKeyHashedDictionary::getItemsImpl(
const Attribute & attribute, const Columns & key_columns, ValueSetter && set_value, DefaultGetter && get_default) const

View File

@ -218,16 +218,10 @@ private:
Attribute createAttributeWithType(const AttributeUnderlyingType type, const Field & null_value);
template <typename OutputType, typename ValueSetter, typename DefaultGetter>
void
getItemsNumber(const Attribute & attribute, const Columns & key_columns, ValueSetter && set_value, DefaultGetter && get_default) const;
template <typename AttributeType, typename OutputType, typename ValueSetter, typename DefaultGetter>
void
getItemsImpl(const Attribute & attribute, const Columns & key_columns, ValueSetter && set_value, DefaultGetter && get_default) const;
template <typename T>
bool setAttributeValueImpl(Attribute & attribute, const StringRef key, const T value);

View File

@ -40,44 +40,6 @@ namespace
} // namespace
bool isAttributeTypeConvertibleTo(AttributeUnderlyingType from, AttributeUnderlyingType to)
{
if (from == to)
return true;
/** This enum can be somewhat incomplete and the meaning may not coincide with NumberTraits.h.
* (for example, because integers can not be converted to floats)
* This is normal for a limited usage scope.
*/
if ((from == AttributeUnderlyingType::UInt8 && to == AttributeUnderlyingType::UInt16)
|| (from == AttributeUnderlyingType::UInt8 && to == AttributeUnderlyingType::UInt32)
|| (from == AttributeUnderlyingType::UInt8 && to == AttributeUnderlyingType::UInt64)
|| (from == AttributeUnderlyingType::UInt16 && to == AttributeUnderlyingType::UInt32)
|| (from == AttributeUnderlyingType::UInt16 && to == AttributeUnderlyingType::UInt64)
|| (from == AttributeUnderlyingType::UInt32 && to == AttributeUnderlyingType::UInt64)
|| (from == AttributeUnderlyingType::UInt8 && to == AttributeUnderlyingType::Int16)
|| (from == AttributeUnderlyingType::UInt8 && to == AttributeUnderlyingType::Int32)
|| (from == AttributeUnderlyingType::UInt8 && to == AttributeUnderlyingType::Int64)
|| (from == AttributeUnderlyingType::UInt16 && to == AttributeUnderlyingType::Int32)
|| (from == AttributeUnderlyingType::UInt16 && to == AttributeUnderlyingType::Int64)
|| (from == AttributeUnderlyingType::UInt32 && to == AttributeUnderlyingType::Int64)
|| (from == AttributeUnderlyingType::Int8 && to == AttributeUnderlyingType::Int16)
|| (from == AttributeUnderlyingType::Int8 && to == AttributeUnderlyingType::Int32)
|| (from == AttributeUnderlyingType::Int8 && to == AttributeUnderlyingType::Int64)
|| (from == AttributeUnderlyingType::Int16 && to == AttributeUnderlyingType::Int32)
|| (from == AttributeUnderlyingType::Int16 && to == AttributeUnderlyingType::Int64)
|| (from == AttributeUnderlyingType::Int32 && to == AttributeUnderlyingType::Int64)
|| (from == AttributeUnderlyingType::Float32 && to == AttributeUnderlyingType::Float64))
{
return true;
}
return false;
}
AttributeUnderlyingType getAttributeUnderlyingType(const std::string & type)
{
static const std::unordered_map<std::string, AttributeUnderlyingType> dictionary{

View File

@ -13,6 +13,12 @@
namespace DB
{
namespace ErrorCodes
{
extern const int TYPE_MISMATCH;
}
enum class AttributeUnderlyingType
{
UInt8,
@ -33,14 +39,18 @@ enum class AttributeUnderlyingType
};
/** For implicit conversions in dictGet functions.
*/
bool isAttributeTypeConvertibleTo(AttributeUnderlyingType from, AttributeUnderlyingType to);
AttributeUnderlyingType getAttributeUnderlyingType(const std::string & type);
std::string toString(const AttributeUnderlyingType type);
/// Implicit conversions in dictGet functions is disabled.
inline void checkAttributeType(const std::string & dict_name, const std::string & attribute_name,
AttributeUnderlyingType attribute_type, AttributeUnderlyingType to)
{
if (attribute_type != to)
throw Exception{dict_name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute_type)
+ ", expected " + toString(to), ErrorCodes::TYPE_MISMATCH};
}
/// Min and max lifetimes for a dictionary or it's entry
using DictionaryLifetime = ExternalLoadableLifetime;

View File

@ -55,7 +55,7 @@ void FlatDictionary::toParent(const PaddedPODArray<Key> & ids, PaddedPODArray<Ke
{
const auto null_value = std::get<UInt64>(hierarchical_attribute->null_values);
getItemsNumber<UInt64>(
getItemsImpl<UInt64, UInt64>(
*hierarchical_attribute,
ids,
[&](const size_t row, const UInt64 value) { out[row] = value; },
@ -117,13 +117,11 @@ void FlatDictionary::isInConstantVector(const Key child_id, const PaddedPODArray
void FlatDictionary::get##TYPE(const std::string & attribute_name, const PaddedPODArray<Key> & ids, ResultArrayType<TYPE> & out) const \
{ \
const auto & attribute = getAttribute(attribute_name); \
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::TYPE)) \
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type), \
ErrorCodes::TYPE_MISMATCH}; \
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::TYPE); \
\
const auto null_value = std::get<TYPE>(attribute.null_values); \
\
getItemsNumber<TYPE>( \
getItemsImpl<TYPE, TYPE>( \
attribute, ids, [&](const size_t row, const auto value) { out[row] = value; }, [&](const size_t) { return null_value; }); \
}
DECLARE(UInt8)
@ -145,9 +143,7 @@ DECLARE(Decimal128)
void FlatDictionary::getString(const std::string & attribute_name, const PaddedPODArray<Key> & ids, ColumnString * out) const
{
const auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
const auto & null_value = std::get<StringRef>(attribute.null_values);
@ -166,11 +162,9 @@ void FlatDictionary::getString(const std::string & attribute_name, const PaddedP
ResultArrayType<TYPE> & out) const \
{ \
const auto & attribute = getAttribute(attribute_name); \
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::TYPE)) \
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type), \
ErrorCodes::TYPE_MISMATCH}; \
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::TYPE); \
\
getItemsNumber<TYPE>( \
getItemsImpl<TYPE, TYPE>( \
attribute, ids, [&](const size_t row, const auto value) { out[row] = value; }, [&](const size_t row) { return def[row]; }); \
}
DECLARE(UInt8)
@ -193,9 +187,7 @@ void FlatDictionary::getString(
const std::string & attribute_name, const PaddedPODArray<Key> & ids, const ColumnString * const def, ColumnString * const out) const
{
const auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
getItemsImpl<StringRef, StringRef>(
attribute,
@ -209,11 +201,9 @@ void FlatDictionary::getString(
const std::string & attribute_name, const PaddedPODArray<Key> & ids, const TYPE def, ResultArrayType<TYPE> & out) const \
{ \
const auto & attribute = getAttribute(attribute_name); \
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::TYPE)) \
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type), \
ErrorCodes::TYPE_MISMATCH}; \
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::TYPE); \
\
getItemsNumber<TYPE>( \
getItemsImpl<TYPE, TYPE>( \
attribute, ids, [&](const size_t row, const auto value) { out[row] = value; }, [&](const size_t) { return def; }); \
}
DECLARE(UInt8)
@ -236,9 +226,7 @@ void FlatDictionary::getString(
const std::string & attribute_name, const PaddedPODArray<Key> & ids, const String & def, ColumnString * const out) const
{
const auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
FlatDictionary::getItemsImpl<StringRef, StringRef>(
attribute,
@ -580,35 +568,6 @@ FlatDictionary::Attribute FlatDictionary::createAttributeWithType(const Attribut
}
template <typename OutputType, typename ValueSetter, typename DefaultGetter>
void FlatDictionary::getItemsNumber(
const Attribute & attribute, const PaddedPODArray<Key> & ids, ValueSetter && set_value, DefaultGetter && get_default) const
{
if (false)
{
}
#define DISPATCH(TYPE) \
else if (attribute.type == AttributeUnderlyingType::TYPE) \
getItemsImpl<TYPE, OutputType>(attribute, ids, std::forward<ValueSetter>(set_value), std::forward<DefaultGetter>(get_default));
DISPATCH(UInt8)
DISPATCH(UInt16)
DISPATCH(UInt32)
DISPATCH(UInt64)
DISPATCH(UInt128)
DISPATCH(Int8)
DISPATCH(Int16)
DISPATCH(Int32)
DISPATCH(Int64)
DISPATCH(Float32)
DISPATCH(Float64)
DISPATCH(Decimal32)
DISPATCH(Decimal64)
DISPATCH(Decimal128)
#undef DISPATCH
else throw Exception("Unexpected type of attribute: " + toString(attribute.type), ErrorCodes::LOGICAL_ERROR);
}
template <typename AttributeType, typename OutputType, typename ValueSetter, typename DefaultGetter>
void FlatDictionary::getItemsImpl(
const Attribute & attribute, const PaddedPODArray<Key> & ids, ValueSetter && set_value, DefaultGetter && get_default) const
@ -701,10 +660,10 @@ void FlatDictionary::setAttributeValue(Attribute & attribute, const Key id, cons
break;
case AttributeUnderlyingType::Decimal32:
setAttributeValueImpl<Decimal32>(attribute, id, value.get<Decimal128>());
setAttributeValueImpl<Decimal32>(attribute, id, value.get<Decimal32>());
break;
case AttributeUnderlyingType::Decimal64:
setAttributeValueImpl<Decimal64>(attribute, id, value.get<Decimal128>());
setAttributeValueImpl<Decimal64>(attribute, id, value.get<Decimal64>());
break;
case AttributeUnderlyingType::Decimal128:
setAttributeValueImpl<Decimal128>(attribute, id, value.get<Decimal128>());

View File

@ -206,10 +206,6 @@ private:
Attribute createAttributeWithType(const AttributeUnderlyingType type, const Field & null_value);
template <typename OutputType, typename ValueSetter, typename DefaultGetter>
void getItemsNumber(
const Attribute & attribute, const PaddedPODArray<Key> & ids, ValueSetter && set_value, DefaultGetter && get_default) const;
template <typename AttributeType, typename OutputType, typename ValueSetter, typename DefaultGetter>
void getItemsImpl(
const Attribute & attribute, const PaddedPODArray<Key> & ids, ValueSetter && set_value, DefaultGetter && get_default) const;

View File

@ -49,7 +49,7 @@ void HashedDictionary::toParent(const PaddedPODArray<Key> & ids, PaddedPODArray<
{
const auto null_value = std::get<UInt64>(hierarchical_attribute->null_values);
getItemsNumber<UInt64>(
getItemsImpl<UInt64, UInt64>(
*hierarchical_attribute,
ids,
[&](const size_t row, const UInt64 value) { out[row] = value; },
@ -116,13 +116,11 @@ void HashedDictionary::isInConstantVector(const Key child_id, const PaddedPODArr
const \
{ \
const auto & attribute = getAttribute(attribute_name); \
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::TYPE)) \
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type), \
ErrorCodes::TYPE_MISMATCH}; \
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::TYPE); \
\
const auto null_value = std::get<TYPE>(attribute.null_values); \
\
getItemsNumber<TYPE>( \
getItemsImpl<TYPE, TYPE>( \
attribute, ids, [&](const size_t row, const auto value) { out[row] = value; }, [&](const size_t) { return null_value; }); \
}
DECLARE(UInt8)
@ -144,9 +142,7 @@ DECLARE(Decimal128)
void HashedDictionary::getString(const std::string & attribute_name, const PaddedPODArray<Key> & ids, ColumnString * out) const
{
const auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
const auto & null_value = StringRef{std::get<String>(attribute.null_values)};
@ -165,11 +161,9 @@ void HashedDictionary::getString(const std::string & attribute_name, const Padde
ResultArrayType<TYPE> & out) const \
{ \
const auto & attribute = getAttribute(attribute_name); \
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::TYPE)) \
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type), \
ErrorCodes::TYPE_MISMATCH}; \
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::TYPE); \
\
getItemsNumber<TYPE>( \
getItemsImpl<TYPE, TYPE>( \
attribute, ids, [&](const size_t row, const auto value) { out[row] = value; }, [&](const size_t row) { return def[row]; }); \
}
DECLARE(UInt8)
@ -192,9 +186,7 @@ void HashedDictionary::getString(
const std::string & attribute_name, const PaddedPODArray<Key> & ids, const ColumnString * const def, ColumnString * const out) const
{
const auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
getItemsImpl<StringRef, StringRef>(
attribute,
@ -208,11 +200,9 @@ void HashedDictionary::getString(
const std::string & attribute_name, const PaddedPODArray<Key> & ids, const TYPE & def, ResultArrayType<TYPE> & out) const \
{ \
const auto & attribute = getAttribute(attribute_name); \
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::TYPE)) \
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type), \
ErrorCodes::TYPE_MISMATCH}; \
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::TYPE); \
\
getItemsNumber<TYPE>( \
getItemsImpl<TYPE, TYPE>( \
attribute, ids, [&](const size_t row, const auto value) { out[row] = value; }, [&](const size_t) { return def; }); \
}
DECLARE(UInt8)
@ -235,9 +225,7 @@ void HashedDictionary::getString(
const std::string & attribute_name, const PaddedPODArray<Key> & ids, const String & def, ColumnString * const out) const
{
const auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
getItemsImpl<StringRef, StringRef>(
attribute,
@ -324,7 +312,6 @@ void HashedDictionary::createAttributes()
void HashedDictionary::blockToAttributes(const Block & block)
{
const auto & id_column = *block.safeGetByPosition(0).column;
element_count += id_column.size();
for (const size_t attribute_idx : ext::range(0, attributes.size()))
{
@ -332,7 +319,8 @@ void HashedDictionary::blockToAttributes(const Block & block)
auto & attribute = attributes[attribute_idx];
for (const auto row_idx : ext::range(0, id_column.size()))
setAttributeValue(attribute, id_column[row_idx].get<UInt64>(), attribute_column[row_idx]);
if (setAttributeValue(attribute, id_column[row_idx].get<UInt64>(), attribute_column[row_idx]))
++element_count;
}
}
@ -567,34 +555,6 @@ HashedDictionary::Attribute HashedDictionary::createAttributeWithType(const Attr
}
template <typename OutputType, typename ValueSetter, typename DefaultGetter>
void HashedDictionary::getItemsNumber(
const Attribute & attribute, const PaddedPODArray<Key> & ids, ValueSetter && set_value, DefaultGetter && get_default) const
{
if (false)
{
}
#define DISPATCH(TYPE) \
else if (attribute.type == AttributeUnderlyingType::TYPE) \
getItemsImpl<TYPE, OutputType>(attribute, ids, std::forward<ValueSetter>(set_value), std::forward<DefaultGetter>(get_default));
DISPATCH(UInt8)
DISPATCH(UInt16)
DISPATCH(UInt32)
DISPATCH(UInt64)
DISPATCH(UInt128)
DISPATCH(Int8)
DISPATCH(Int16)
DISPATCH(Int32)
DISPATCH(Int64)
DISPATCH(Float32)
DISPATCH(Float64)
DISPATCH(Decimal32)
DISPATCH(Decimal64)
DISPATCH(Decimal128)
#undef DISPATCH
else throw Exception("Unexpected type of attribute: " + toString(attribute.type), ErrorCodes::LOGICAL_ERROR);
}
template <typename AttributeType, typename OutputType, typename ValueSetter, typename DefaultGetter>
void HashedDictionary::getItemsImpl(
const Attribute & attribute, const PaddedPODArray<Key> & ids, ValueSetter && set_value, DefaultGetter && get_default) const
@ -613,69 +573,56 @@ void HashedDictionary::getItemsImpl(
template <typename T>
void HashedDictionary::setAttributeValueImpl(Attribute & attribute, const Key id, const T value)
bool HashedDictionary::setAttributeValueImpl(Attribute & attribute, const Key id, const T value)
{
auto & map = *std::get<CollectionPtrType<T>>(attribute.maps);
map.insert({id, value});
return map.insert({id, value}).second;
}
void HashedDictionary::setAttributeValue(Attribute & attribute, const Key id, const Field & value)
bool HashedDictionary::setAttributeValue(Attribute & attribute, const Key id, const Field & value)
{
switch (attribute.type)
{
case AttributeUnderlyingType::UInt8:
setAttributeValueImpl<UInt8>(attribute, id, value.get<UInt64>());
break;
return setAttributeValueImpl<UInt8>(attribute, id, value.get<UInt64>());
case AttributeUnderlyingType::UInt16:
setAttributeValueImpl<UInt16>(attribute, id, value.get<UInt64>());
break;
return setAttributeValueImpl<UInt16>(attribute, id, value.get<UInt64>());
case AttributeUnderlyingType::UInt32:
setAttributeValueImpl<UInt32>(attribute, id, value.get<UInt64>());
break;
return setAttributeValueImpl<UInt32>(attribute, id, value.get<UInt64>());
case AttributeUnderlyingType::UInt64:
setAttributeValueImpl<UInt64>(attribute, id, value.get<UInt64>());
break;
return setAttributeValueImpl<UInt64>(attribute, id, value.get<UInt64>());
case AttributeUnderlyingType::UInt128:
setAttributeValueImpl<UInt128>(attribute, id, value.get<UInt128>());
break;
return setAttributeValueImpl<UInt128>(attribute, id, value.get<UInt128>());
case AttributeUnderlyingType::Int8:
setAttributeValueImpl<Int8>(attribute, id, value.get<Int64>());
break;
return setAttributeValueImpl<Int8>(attribute, id, value.get<Int64>());
case AttributeUnderlyingType::Int16:
setAttributeValueImpl<Int16>(attribute, id, value.get<Int64>());
break;
return setAttributeValueImpl<Int16>(attribute, id, value.get<Int64>());
case AttributeUnderlyingType::Int32:
setAttributeValueImpl<Int32>(attribute, id, value.get<Int64>());
break;
return setAttributeValueImpl<Int32>(attribute, id, value.get<Int64>());
case AttributeUnderlyingType::Int64:
setAttributeValueImpl<Int64>(attribute, id, value.get<Int64>());
break;
return setAttributeValueImpl<Int64>(attribute, id, value.get<Int64>());
case AttributeUnderlyingType::Float32:
setAttributeValueImpl<Float32>(attribute, id, value.get<Float64>());
break;
return setAttributeValueImpl<Float32>(attribute, id, value.get<Float64>());
case AttributeUnderlyingType::Float64:
setAttributeValueImpl<Float64>(attribute, id, value.get<Float64>());
break;
return setAttributeValueImpl<Float64>(attribute, id, value.get<Float64>());
case AttributeUnderlyingType::Decimal32:
setAttributeValueImpl<Decimal32>(attribute, id, value.get<Decimal32>());
break;
return setAttributeValueImpl<Decimal32>(attribute, id, value.get<Decimal32>());
case AttributeUnderlyingType::Decimal64:
setAttributeValueImpl<Decimal64>(attribute, id, value.get<Decimal64>());
break;
return setAttributeValueImpl<Decimal64>(attribute, id, value.get<Decimal64>());
case AttributeUnderlyingType::Decimal128:
setAttributeValueImpl<Decimal128>(attribute, id, value.get<Decimal128>());
break;
return setAttributeValueImpl<Decimal128>(attribute, id, value.get<Decimal128>());
case AttributeUnderlyingType::String:
{
auto & map = *std::get<CollectionPtrType<StringRef>>(attribute.maps);
const auto & string = value.get<String>();
const auto string_in_arena = attribute.string_arena->insert(string.data(), string.size());
map.insert({id, StringRef{string_in_arena, string.size()}});
break;
return map.insert({id, StringRef{string_in_arena, string.size()}}).second;
}
}
throw Exception{"Invalid attribute type", ErrorCodes::BAD_ARGUMENTS};
}
const HashedDictionary::Attribute & HashedDictionary::getAttribute(const std::string & attribute_name) const

View File

@ -211,18 +211,14 @@ private:
Attribute createAttributeWithType(const AttributeUnderlyingType type, const Field & null_value);
template <typename OutputType, typename ValueSetter, typename DefaultGetter>
void getItemsNumber(
const Attribute & attribute, const PaddedPODArray<Key> & ids, ValueSetter && set_value, DefaultGetter && get_default) const;
template <typename AttributeType, typename OutputType, typename ValueSetter, typename DefaultGetter>
void getItemsImpl(
const Attribute & attribute, const PaddedPODArray<Key> & ids, ValueSetter && set_value, DefaultGetter && get_default) const;
template <typename T>
void setAttributeValueImpl(Attribute & attribute, const Key id, const T value);
bool setAttributeValueImpl(Attribute & attribute, const Key id, const T value);
void setAttributeValue(Attribute & attribute, const Key id, const Field & value);
bool setAttributeValue(Attribute & attribute, const Key id, const Field & value);
const Attribute & getAttribute(const std::string & attribute_name) const;

View File

@ -6,6 +6,7 @@
#include "DictionaryStructure.h"
namespace DB
{
namespace ErrorCodes
@ -47,7 +48,6 @@ void registerDictionarySourceMysql(DictionarySourceFactory & factory)
# include <Formats/MySQLBlockInputStream.h>
# include "readInvalidateQuery.h"
namespace DB
{
static const UInt64 max_block_size = 8192;
@ -71,6 +71,7 @@ MySQLDictionarySource::MySQLDictionarySource(
, query_builder{dict_struct, db, table, where, IdentifierQuotingStyle::Backticks}
, load_all_query{query_builder.composeLoadAllQuery()}
, invalidate_query{config.getString(config_prefix + ".invalidate_query", "")}
, close_connection{config.getBool(config_prefix + ".close_connection", false)}
{
}
@ -91,6 +92,7 @@ MySQLDictionarySource::MySQLDictionarySource(const MySQLDictionarySource & other
, last_modification{other.last_modification}
, invalidate_query{other.invalidate_query}
, invalidate_query_response{other.invalidate_query_response}
, close_connection{other.close_connection}
{
}
@ -117,7 +119,7 @@ BlockInputStreamPtr MySQLDictionarySource::loadAll()
last_modification = getLastModification();
LOG_TRACE(log, load_all_query);
return std::make_shared<MySQLBlockInputStream>(pool.Get(), load_all_query, sample_block, max_block_size);
return std::make_shared<MySQLBlockInputStream>(pool.Get(), load_all_query, sample_block, max_block_size, close_connection);
}
BlockInputStreamPtr MySQLDictionarySource::loadUpdatedAll()
@ -126,7 +128,7 @@ BlockInputStreamPtr MySQLDictionarySource::loadUpdatedAll()
std::string load_update_query = getUpdateFieldAndDate();
LOG_TRACE(log, load_update_query);
return std::make_shared<MySQLBlockInputStream>(pool.Get(), load_update_query, sample_block, max_block_size);
return std::make_shared<MySQLBlockInputStream>(pool.Get(), load_update_query, sample_block, max_block_size, close_connection);
}
BlockInputStreamPtr MySQLDictionarySource::loadIds(const std::vector<UInt64> & ids)
@ -134,7 +136,7 @@ BlockInputStreamPtr MySQLDictionarySource::loadIds(const std::vector<UInt64> & i
/// We do not log in here and do not update the modification time, as the request can be large, and often called.
const auto query = query_builder.composeLoadIdsQuery(ids);
return std::make_shared<MySQLBlockInputStream>(pool.Get(), query, sample_block, max_block_size);
return std::make_shared<MySQLBlockInputStream>(pool.Get(), query, sample_block, max_block_size, close_connection);
}
BlockInputStreamPtr MySQLDictionarySource::loadKeys(const Columns & key_columns, const std::vector<size_t> & requested_rows)
@ -142,7 +144,7 @@ BlockInputStreamPtr MySQLDictionarySource::loadKeys(const Columns & key_columns,
/// We do not log in here and do not update the modification time, as the request can be large, and often called.
const auto query = query_builder.composeLoadKeysQuery(key_columns, requested_rows, ExternalQueryBuilder::AND_OR_CHAIN);
return std::make_shared<MySQLBlockInputStream>(pool.Get(), query, sample_block, max_block_size);
return std::make_shared<MySQLBlockInputStream>(pool.Get(), query, sample_block, max_block_size, close_connection);
}
bool MySQLDictionarySource::isModified() const
@ -253,7 +255,7 @@ std::string MySQLDictionarySource::doInvalidateQuery(const std::string & request
Block invalidate_sample_block;
ColumnPtr column(ColumnString::create());
invalidate_sample_block.insert(ColumnWithTypeAndName(column, std::make_shared<DataTypeString>(), "Sample Block"));
MySQLBlockInputStream block_input_stream(pool.Get(), request, invalidate_sample_block, 1);
MySQLBlockInputStream block_input_stream(pool.Get(), request, invalidate_sample_block, 1, close_connection);
return readInvalidateQuery(block_input_stream);
}

View File

@ -81,6 +81,7 @@ private:
LocalDateTime last_modification;
std::string invalidate_query;
mutable std::string invalidate_query_response;
const bool close_connection;
};
}

View File

@ -75,13 +75,11 @@ TrieDictionary::~TrieDictionary()
validateKeyTypes(key_types); \
\
const auto & attribute = getAttribute(attribute_name); \
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::TYPE)) \
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type), \
ErrorCodes::TYPE_MISMATCH}; \
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::TYPE); \
\
const auto null_value = std::get<TYPE>(attribute.null_values); \
\
getItemsNumber<TYPE>( \
getItemsImpl<TYPE, TYPE>( \
attribute, \
key_columns, \
[&](const size_t row, const auto value) { out[row] = value; }, \
@ -109,9 +107,7 @@ void TrieDictionary::getString(
validateKeyTypes(key_types);
const auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
const auto & null_value = StringRef{std::get<String>(attribute.null_values)};
@ -133,11 +129,9 @@ void TrieDictionary::getString(
validateKeyTypes(key_types); \
\
const auto & attribute = getAttribute(attribute_name); \
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::TYPE)) \
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type), \
ErrorCodes::TYPE_MISMATCH}; \
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::TYPE); \
\
getItemsNumber<TYPE>( \
getItemsImpl<TYPE, TYPE>( \
attribute, \
key_columns, \
[&](const size_t row, const auto value) { out[row] = value; }, \
@ -169,9 +163,7 @@ void TrieDictionary::getString(
validateKeyTypes(key_types);
const auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
getItemsImpl<StringRef, StringRef>(
attribute,
@ -191,11 +183,9 @@ void TrieDictionary::getString(
validateKeyTypes(key_types); \
\
const auto & attribute = getAttribute(attribute_name); \
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::TYPE)) \
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type), \
ErrorCodes::TYPE_MISMATCH}; \
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::TYPE); \
\
getItemsNumber<TYPE>( \
getItemsImpl<TYPE, TYPE>( \
attribute, key_columns, [&](const size_t row, const auto value) { out[row] = value; }, [&](const size_t) { return def; }); \
}
DECLARE(UInt8)
@ -224,9 +214,7 @@ void TrieDictionary::getString(
validateKeyTypes(key_types);
const auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
getItemsImpl<StringRef, StringRef>(
attribute,
@ -507,34 +495,6 @@ TrieDictionary::Attribute TrieDictionary::createAttributeWithType(const Attribut
}
template <typename OutputType, typename ValueSetter, typename DefaultGetter>
void TrieDictionary::getItemsNumber(
const Attribute & attribute, const Columns & key_columns, ValueSetter && set_value, DefaultGetter && get_default) const
{
if (false)
{
}
#define DISPATCH(TYPE) \
else if (attribute.type == AttributeUnderlyingType::TYPE) getItemsImpl<TYPE, OutputType>( \
attribute, key_columns, std::forward<ValueSetter>(set_value), std::forward<DefaultGetter>(get_default));
DISPATCH(UInt8)
DISPATCH(UInt16)
DISPATCH(UInt32)
DISPATCH(UInt64)
DISPATCH(UInt128)
DISPATCH(Int8)
DISPATCH(Int16)
DISPATCH(Int32)
DISPATCH(Int64)
DISPATCH(Float32)
DISPATCH(Float64)
DISPATCH(Decimal32)
DISPATCH(Decimal64)
DISPATCH(Decimal128)
#undef DISPATCH
else throw Exception("Unexpected type of attribute: " + toString(attribute.type), ErrorCodes::LOGICAL_ERROR);
}
template <typename AttributeType, typename OutputType, typename ValueSetter, typename DefaultGetter>
void TrieDictionary::getItemsImpl(
const Attribute & attribute, const Columns & key_columns, ValueSetter && set_value, DefaultGetter && get_default) const

View File

@ -218,10 +218,6 @@ private:
Attribute createAttributeWithType(const AttributeUnderlyingType type, const Field & null_value);
template <typename OutputType, typename ValueSetter, typename DefaultGetter>
void
getItemsNumber(const Attribute & attribute, const Columns & key_columns, ValueSetter && set_value, DefaultGetter && get_default) const;
template <typename AttributeType, typename OutputType, typename ValueSetter, typename DefaultGetter>
void
getItemsImpl(const Attribute & attribute, const Columns & key_columns, ValueSetter && set_value, DefaultGetter && get_default) const;

View File

@ -340,7 +340,7 @@ bool OPTIMIZE(1) CSVRowInputStream::parseRowAndPrintDiagnosticInfo(MutableColumn
if (curr_position < prev_position)
throw Exception("Logical error: parsing is non-deterministic.", ErrorCodes::LOGICAL_ERROR);
if (isNumber(current_column_type) || isDateOrDateTime(current_column_type))
if (isNativeNumber(current_column_type) || isDateOrDateTime(current_column_type))
{
/// An empty string instead of a value.
if (curr_position == prev_position)

View File

@ -130,6 +130,7 @@ void registerOutputFormatXML(FormatFactory & factory);
void registerOutputFormatODBCDriver(FormatFactory & factory);
void registerOutputFormatODBCDriver2(FormatFactory & factory);
void registerOutputFormatNull(FormatFactory & factory);
void registerOutputFormatMySQLWire(FormatFactory & factory);
/// Input only formats.
@ -168,6 +169,7 @@ FormatFactory::FormatFactory()
registerOutputFormatODBCDriver(*this);
registerOutputFormatODBCDriver2(*this);
registerOutputFormatNull(*this);
registerOutputFormatMySQLWire(*this);
}
}

View File

@ -220,10 +220,10 @@ void JSONRowOutputStream::writeStatistics()
writeText(watch.elapsedSeconds(), *ostr);
writeCString(",\n", *ostr);
writeCString("\t\t\"rows_read\": ", *ostr);
writeText(progress.rows.load(), *ostr);
writeText(progress.read_rows.load(), *ostr);
writeCString(",\n", *ostr);
writeCString("\t\t\"bytes_read\": ", *ostr);
writeText(progress.bytes.load(), *ostr);
writeText(progress.read_bytes.load(), *ostr);
writeChar('\n', *ostr);
writeCString("\t}", *ostr);

View File

@ -20,8 +20,8 @@ namespace ErrorCodes
MySQLBlockInputStream::MySQLBlockInputStream(
const mysqlxx::PoolWithFailover::Entry & entry, const std::string & query_str, const Block & sample_block, const UInt64 max_block_size)
: entry{entry}, query{this->entry->query(query_str)}, result{query.use()}, max_block_size{max_block_size}
const mysqlxx::PoolWithFailover::Entry & entry, const std::string & query_str, const Block & sample_block, const UInt64 max_block_size, const bool auto_close)
: entry{entry}, query{this->entry->query(query_str)}, result{query.use()}, max_block_size{max_block_size}, auto_close{auto_close}
{
if (sample_block.columns() != result.getNumFields())
throw Exception{"mysqlxx::UseQueryResult contains " + toString(result.getNumFields()) + " columns while "
@ -93,7 +93,11 @@ Block MySQLBlockInputStream::readImpl()
{
auto row = result.fetch();
if (!row)
{
if (auto_close)
entry.disconnect();
return {};
}
MutableColumns columns(description.sample_block.columns());
for (const auto i : ext::range(0, columns.size()))
@ -126,7 +130,8 @@ Block MySQLBlockInputStream::readImpl()
row = result.fetch();
}
if (auto_close)
entry.disconnect();
return description.sample_block.cloneWithColumns(std::move(columns));
}

View File

@ -18,7 +18,8 @@ public:
const mysqlxx::PoolWithFailover::Entry & entry,
const std::string & query_str,
const Block & sample_block,
const UInt64 max_block_size);
const UInt64 max_block_size,
const bool auto_close = false);
String getName() const override { return "MySQL"; }
@ -31,6 +32,7 @@ private:
mysqlxx::Query query;
mysqlxx::UseQueryResult result;
const UInt64 max_block_size;
const bool auto_close;
ExternalResultDescription description;
};

View File

@ -0,0 +1,86 @@
#include "MySQLWireBlockOutputStream.h"
#include <Core/MySQLProtocol.h>
#include <Interpreters/ProcessList.h>
#include <iomanip>
#include <sstream>
namespace DB
{
using namespace MySQLProtocol;
MySQLWireBlockOutputStream::MySQLWireBlockOutputStream(WriteBuffer & buf, const Block & header, Context & context)
: header(header)
, context(context)
, packet_sender(std::make_shared<PacketSender>(buf, context.sequence_id))
{
packet_sender->max_packet_size = context.max_packet_size;
}
void MySQLWireBlockOutputStream::writePrefix()
{
if (header.columns() == 0)
return;
packet_sender->sendPacket(LengthEncodedNumber(header.columns()));
for (const ColumnWithTypeAndName & column : header.getColumnsWithTypeAndName())
{
ColumnDefinition column_definition(column.name, CharacterSet::binary, 0, ColumnType::MYSQL_TYPE_STRING, 0, 0);
packet_sender->sendPacket(column_definition);
}
if (!(context.client_capabilities & Capability::CLIENT_DEPRECATE_EOF))
{
packet_sender->sendPacket(EOF_Packet(0, 0));
}
}
void MySQLWireBlockOutputStream::write(const Block & block)
{
size_t rows = block.rows();
for (size_t i = 0; i < rows; i++)
{
ResultsetRow row_packet;
for (const ColumnWithTypeAndName & column : block)
{
String column_value;
WriteBufferFromString ostr(column_value);
column.type->serializeAsText(*column.column.get(), i, ostr, format_settings);
ostr.finish();
row_packet.appendColumn(std::move(column_value));
}
packet_sender->sendPacket(row_packet);
}
}
void MySQLWireBlockOutputStream::writeSuffix()
{
QueryStatus * process_list_elem = context.getProcessListElement();
CurrentThread::finalizePerformanceCounters();
QueryStatusInfo info = process_list_elem->getInfo();
size_t affected_rows = info.written_rows;
std::stringstream human_readable_info;
human_readable_info << std::fixed << std::setprecision(3)
<< "Read " << info.read_rows << " rows, " << formatReadableSizeWithBinarySuffix(info.read_bytes) << " in " << info.elapsed_seconds << " sec., "
<< static_cast<size_t>(info.read_rows / info.elapsed_seconds) << " rows/sec., "
<< formatReadableSizeWithBinarySuffix(info.read_bytes / info.elapsed_seconds) << "/sec.";
if (header.columns() == 0)
packet_sender->sendPacket(OK_Packet(0x0, context.client_capabilities, affected_rows, 0, 0, "", human_readable_info.str()), true);
else
if (context.client_capabilities & CLIENT_DEPRECATE_EOF)
packet_sender->sendPacket(OK_Packet(0xfe, context.client_capabilities, affected_rows, 0, 0, "", human_readable_info.str()), true);
else
packet_sender->sendPacket(EOF_Packet(0, 0), true);
}
void MySQLWireBlockOutputStream::flush()
{
packet_sender->out->next();
}
}

View File

@ -0,0 +1,36 @@
#pragma once
#include <Core/MySQLProtocol.h>
#include <DataStreams/IBlockOutputStream.h>
#include <Formats/FormatFactory.h>
#include <Formats/FormatSettings.h>
#include <Interpreters/Context.h>
namespace DB
{
/** Interface for writing rows in MySQL Client/Server Protocol format.
*/
class MySQLWireBlockOutputStream : public IBlockOutputStream
{
public:
MySQLWireBlockOutputStream(WriteBuffer & buf, const Block & header, Context & context);
Block getHeader() const { return header; }
void write(const Block & block);
void writePrefix();
void writeSuffix();
void flush();
private:
Block header;
Context & context;
std::shared_ptr<MySQLProtocol::PacketSender> packet_sender;
FormatSettings format_settings;
};
using MySQLWireBlockOutputStreamPtr = std::shared_ptr<MySQLWireBlockOutputStream>;
}

View File

@ -0,0 +1,19 @@
#include <Formats/MySQLWireBlockOutputStream.h>
namespace DB
{
void registerOutputFormatMySQLWire(FormatFactory & factory)
{
factory.registerOutputFormat("MySQLWire", [](
WriteBuffer & buf,
const Block & sample,
const Context & context,
const FormatSettings &)
{
return std::make_shared<MySQLWireBlockOutputStream>(buf, sample, const_cast<Context &>(context));
});
}
}

View File

@ -308,7 +308,7 @@ bool OPTIMIZE(1) TabSeparatedRowInputStream::parseRowAndPrintDiagnosticInfo(
if (curr_position < prev_position)
throw Exception("Logical error: parsing is non-deterministic.", ErrorCodes::LOGICAL_ERROR);
if (isNumber(current_column_type) || isDateOrDateTime(current_column_type))
if (isNativeNumber(current_column_type) || isDateOrDateTime(current_column_type))
{
/// An empty string instead of a value.
if (curr_position == prev_position)

View File

@ -215,10 +215,10 @@ void XMLRowOutputStream::writeStatistics()
writeText(watch.elapsedSeconds(), *ostr);
writeCString("</elapsed>\n", *ostr);
writeCString("\t\t<rows_read>", *ostr);
writeText(progress.rows.load(), *ostr);
writeText(progress.read_rows.load(), *ostr);
writeCString("</rows_read>\n", *ostr);
writeCString("\t\t<bytes_read>", *ostr);
writeText(progress.bytes.load(), *ostr);
writeText(progress.read_bytes.load(), *ostr);
writeCString("</bytes_read>\n", *ostr);
writeCString("\t</statistics>\n", *ostr);
}

View File

@ -263,7 +263,7 @@ public:
+ toString(arguments.size()) + ", should be 2 or 3",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
if (!isNumber(arguments[1].type))
if (!isNativeNumber(arguments[1].type))
throw Exception("Second argument for function " + getName() + " (delta) must be number",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

View File

@ -59,7 +59,7 @@ private:
{
const auto check_argument_type = [this] (const IDataType * arg)
{
if (!isNumber(arg))
if (!isNativeNumber(arg))
throw Exception{"Illegal type " + arg->getName() + " of argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
};

View File

@ -56,7 +56,7 @@ private:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
const auto & arg = arguments.front();
if (!isNumber(arg) && !isDecimal(arg))
if (!isNumber(arg))
throw Exception{"Illegal type " + arg->getName() + " of argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};

View File

@ -37,7 +37,7 @@ public:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!isNumber(arguments.front()))
if (!isNativeNumber(arguments.front()))
throw Exception{"Argument for function " + getName() + " must be number", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
return std::make_shared<DataTypeUInt8>();

View File

@ -21,5 +21,7 @@ void registerFunctionsBitmap(FunctionFactory & factory)
factory.registerFunction<FunctionBitmapXor>();
factory.registerFunction<FunctionBitmapAndnot>();
factory.registerFunction<FunctionBitmapHasAll>();
factory.registerFunction<FunctionBitmapHasAny>();
}
}

View File

@ -342,7 +342,27 @@ struct BitmapAndnotCardinalityImpl
}
};
template <template <typename> class Impl, typename Name>
template <typename T>
struct BitmapHasAllImpl
{
using ReturnType = UInt8;
static UInt8 apply(const AggregateFunctionGroupBitmapData<T> & bd1, const AggregateFunctionGroupBitmapData<T> & bd2)
{
return bd1.rbs.rb_is_subset(bd2.rbs);
}
};
template <typename T>
struct BitmapHasAnyImpl
{
using ReturnType = UInt8;
static UInt8 apply(const AggregateFunctionGroupBitmapData<T> & bd1, const AggregateFunctionGroupBitmapData<T> & bd2)
{
return bd1.rbs.rb_intersect(bd2.rbs);
}
};
template <template <typename> class Impl, typename Name, typename ToType>
class FunctionBitmapCardinality : public IFunction
{
public:
@ -369,6 +389,13 @@ public:
throw Exception(
"Second argument for function " + getName() + " must be an bitmap but it has type " + arguments[1]->getName() + ".",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (bitmap_type0->getArgumentsDataTypes()[0]->getTypeId() != bitmap_type1->getArgumentsDataTypes()[0]->getTypeId())
throw Exception(
"The nested type in bitmaps must be the same, but one is " + bitmap_type0->getArgumentsDataTypes()[0]->getName()
+ ", and the other is " + bitmap_type1->getArgumentsDataTypes()[0]->getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return std::make_shared<DataTypeNumber<ToType>>();
}
@ -398,8 +425,6 @@ public:
}
private:
using ToType = UInt64;
template <typename T>
void executeIntType(
Block & block, const ColumnNumbers & arguments, size_t input_rows_count, typename ColumnVector<ToType>::Container & vec_to)
@ -487,6 +512,13 @@ public:
throw Exception(
"Second argument for function " + getName() + " must be an bitmap but it has type " + arguments[1]->getName() + ".",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (bitmap_type0->getArgumentsDataTypes()[0]->getTypeId() != bitmap_type1->getArgumentsDataTypes()[0]->getTypeId())
throw Exception(
"The nested type in bitmaps must be the same, but one is " + bitmap_type0->getArgumentsDataTypes()[0]->getName()
+ ", and the other is " + bitmap_type1->getArgumentsDataTypes()[0]->getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return arguments[0];
}
@ -571,12 +603,22 @@ struct NameBitmapAndnotCardinality
{
static constexpr auto name = "bitmapAndnotCardinality";
};
struct NameBitmapHasAll
{
static constexpr auto name = "bitmapHasAll";
};
struct NameBitmapHasAny
{
static constexpr auto name = "bitmapHasAny";
};
using FunctionBitmapSelfCardinality = FunctionBitmapSelfCardinalityImpl<NameBitmapCardinality>;
using FunctionBitmapAndCardinality = FunctionBitmapCardinality<BitmapAndCardinalityImpl, NameBitmapAndCardinality>;
using FunctionBitmapOrCardinality = FunctionBitmapCardinality<BitmapOrCardinalityImpl, NameBitmapOrCardinality>;
using FunctionBitmapXorCardinality = FunctionBitmapCardinality<BitmapXorCardinalityImpl, NameBitmapXorCardinality>;
using FunctionBitmapAndnotCardinality = FunctionBitmapCardinality<BitmapAndnotCardinalityImpl, NameBitmapAndnotCardinality>;
using FunctionBitmapAndCardinality = FunctionBitmapCardinality<BitmapAndCardinalityImpl, NameBitmapAndCardinality, UInt64>;
using FunctionBitmapOrCardinality = FunctionBitmapCardinality<BitmapOrCardinalityImpl, NameBitmapOrCardinality, UInt64>;
using FunctionBitmapXorCardinality = FunctionBitmapCardinality<BitmapXorCardinalityImpl, NameBitmapXorCardinality, UInt64>;
using FunctionBitmapAndnotCardinality = FunctionBitmapCardinality<BitmapAndnotCardinalityImpl, NameBitmapAndnotCardinality, UInt64>;
using FunctionBitmapHasAll = FunctionBitmapCardinality<BitmapHasAllImpl, NameBitmapHasAll, UInt8>;
using FunctionBitmapHasAny = FunctionBitmapCardinality<BitmapHasAnyImpl, NameBitmapHasAny, UInt8>;
struct NameBitmapAnd
{

View File

@ -20,7 +20,7 @@ void throwExceptionForIncompletelyParsedValue(
else
message_buf << " at begin of string";
if (isNumber(to_type))
if (isNativeNumber(to_type))
message_buf << ". Note: there are to" << to_type.getName() << "OrZero and to" << to_type.getName() << "OrNull functions, which returns zero/NULL instead of throwing exception.";
throw Exception(message_buf.str(), ErrorCodes::CANNOT_PARSE_TEXT);

View File

@ -1785,7 +1785,7 @@ private:
return createStringToEnumWrapper<ColumnString, EnumType>();
else if (checkAndGetDataType<DataTypeFixedString>(from_type.get()))
return createStringToEnumWrapper<ColumnFixedString, EnumType>();
else if (isNumber(from_type) || isEnum(from_type))
else if (isNativeNumber(from_type) || isEnum(from_type))
{
auto function = Function::create(context);

View File

@ -1,6 +1,7 @@
#pragma once
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypesDecimal.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypeDate.h>
@ -661,21 +662,77 @@ DECLARE_DICT_GET_TRAITS(UInt32, DataTypeDateTime)
DECLARE_DICT_GET_TRAITS(UInt128, DataTypeUUID)
#undef DECLARE_DICT_GET_TRAITS
template <typename T> struct DictGetTraits<DataTypeDecimal<T>>
{
static constexpr bool is_dec32 = std::is_same_v<T, Decimal32>;
static constexpr bool is_dec64 = std::is_same_v<T, Decimal64>;
static constexpr bool is_dec128 = std::is_same_v<T, Decimal128>;
template <typename DictionaryType>
static void get(const DictionaryType * dict, const std::string & name, const PaddedPODArray<UInt64> & ids,
DecimalPaddedPODArray<T> & out)
{
if constexpr (is_dec32) dict->getDecimal32(name, ids, out);
if constexpr (is_dec64) dict->getDecimal64(name, ids, out);
if constexpr (is_dec128) dict->getDecimal128(name, ids, out);
}
template <typename DictionaryType>
static void get(const DictionaryType * dict, const std::string & name, const Columns & key_columns, const DataTypes & key_types,
DecimalPaddedPODArray<T> & out)
{
if constexpr (is_dec32) dict->getDecimal32(name, key_columns, key_types, out);
if constexpr (is_dec64) dict->getDecimal64(name, key_columns, key_types, out);
if constexpr (is_dec128) dict->getDecimal128(name, key_columns, key_types, out);
}
template <typename DictionaryType>
static void get(const DictionaryType * dict, const std::string & name, const PaddedPODArray<UInt64> & ids,
const PaddedPODArray<Int64> & dates, DecimalPaddedPODArray<T> & out)
{
if constexpr (is_dec32) dict->getDecimal32(name, ids, dates, out);
if constexpr (is_dec64) dict->getDecimal64(name, ids, dates, out);
if constexpr (is_dec128) dict->getDecimal128(name, ids, dates, out);
}
template <typename DictionaryType, typename DefaultsType>
static void getOrDefault(const DictionaryType * dict, const std::string & name, const PaddedPODArray<UInt64> & ids,
const DefaultsType & def, DecimalPaddedPODArray<T> & out)
{
if constexpr (is_dec32) dict->getDecimal32(name, ids, def, out);
if constexpr (is_dec64) dict->getDecimal64(name, ids, def, out);
if constexpr (is_dec128) dict->getDecimal128(name, ids, def, out);
}
template <typename DictionaryType, typename DefaultsType>
static void getOrDefault(const DictionaryType * dict, const std::string & name, const Columns & key_columns,
const DataTypes & key_types, const DefaultsType & def, DecimalPaddedPODArray<T> & out)
{
if constexpr (is_dec32) dict->getDecimal32(name, key_columns, key_types, def, out);
if constexpr (is_dec64) dict->getDecimal64(name, key_columns, key_types, def, out);
if constexpr (is_dec128) dict->getDecimal128(name, key_columns, key_types, def, out);
}
};
template <typename DataType, typename Name>
class FunctionDictGet final : public IFunction
{
using Type = typename DataType::FieldType;
using ColVec = std::conditional_t<IsDecimalNumber<Type>, ColumnDecimal<Type>, ColumnVector<Type>>;
public:
static constexpr auto name = Name::name;
static FunctionPtr create(const Context & context)
static FunctionPtr create(const Context & context, UInt32 dec_scale = 0)
{
return std::make_shared<FunctionDictGet>(context.getExternalDictionaries());
return std::make_shared<FunctionDictGet>(context.getExternalDictionaries(), dec_scale);
}
FunctionDictGet(const ExternalDictionaries & dictionaries) : dictionaries(dictionaries) {}
FunctionDictGet(const ExternalDictionaries & dictionaries, UInt32 dec_scale = 0)
: dictionaries(dictionaries)
, decimal_scale(dec_scale)
{}
String getName() const override { return name; }
@ -719,7 +776,10 @@ private:
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
}
return std::make_shared<DataType>();
if constexpr (IsDataTypeDecimal<DataType>)
return std::make_shared<DataType>(DataType::maxPrecision(), decimal_scale);
else
return std::make_shared<DataType>();
}
bool isDeterministic() const override { return false; }
@ -771,7 +831,11 @@ private:
const auto id_col_untyped = block.getByPosition(arguments[2]).column.get();
if (const auto id_col = checkAndGetColumn<ColumnUInt64>(id_col_untyped))
{
auto out = ColumnVector<Type>::create(id_col->size());
typename ColVec::MutablePtr out;
if constexpr (IsDataTypeDecimal<DataType>)
out = ColVec::create(id_col->size(), decimal_scale);
else
out = ColVec::create(id_col->size());
const auto & ids = id_col->getData();
auto & data = out->getData();
DictGetTraits<DataType>::get(dict, attr_name, ids, data);
@ -780,9 +844,21 @@ private:
else if (const auto id_col_const = checkAndGetColumnConst<ColumnVector<UInt64>>(id_col_untyped))
{
const PaddedPODArray<UInt64> ids(1, id_col_const->getValue<UInt64>());
PaddedPODArray<Type> data(1);
DictGetTraits<DataType>::get(dict, attr_name, ids, data);
block.getByPosition(result).column = DataTypeNumber<Type>().createColumnConst(id_col_const->size(), toField(data.front()));
if constexpr (IsDataTypeDecimal<DataType>)
{
DecimalPaddedPODArray<Type> data(1, decimal_scale);
DictGetTraits<DataType>::get(dict, attr_name, ids, data);
block.getByPosition(result).column =
DataType(DataType::maxPrecision(), decimal_scale).createColumnConst(
id_col_const->size(), toField(data.front(), decimal_scale));
}
else
{
PaddedPODArray<Type> data(1);
DictGetTraits<DataType>::get(dict, attr_name, ids, data);
block.getByPosition(result).column = DataTypeNumber<Type>().createColumnConst(id_col_const->size(), toField(data.front()));
}
}
else
throw Exception{"Third argument of function " + getName() + " must be UInt64", ErrorCodes::ILLEGAL_COLUMN};
@ -818,7 +894,11 @@ private:
const auto & key_columns = static_cast<const ColumnTuple &>(*key_col).getColumnsCopy();
const auto & key_types = static_cast<const DataTypeTuple &>(*key_col_with_type.type).getElements();
auto out = ColumnVector<Type>::create(key_columns.front()->size());
typename ColVec::MutablePtr out;
if constexpr (IsDataTypeDecimal<DataType>)
out = ColVec::create(key_columns.front()->size(), decimal_scale);
else
out = ColVec::create(key_columns.front()->size());
auto & data = out->getData();
DictGetTraits<DataType>::get(dict, attr_name, key_columns, key_types, data);
block.getByPosition(result).column = std::move(out);
@ -855,7 +935,11 @@ private:
const auto & id_col_values = getColumnDataAsPaddedPODArray(*id_col_untyped, id_col_values_storage);
const auto & range_col_values = getColumnDataAsPaddedPODArray(*range_col_untyped, range_col_values_storage);
auto out = ColumnVector<Type>::create(id_col_untyped->size());
typename ColVec::MutablePtr out;
if constexpr (IsDataTypeDecimal<DataType>)
out = ColVec::create(id_col_untyped->size(), decimal_scale);
else
out = ColVec::create(id_col_untyped->size());
auto & data = out->getData();
DictGetTraits<DataType>::get(dict, attr_name, id_col_values, range_col_values, data);
block.getByPosition(result).column = std::move(out);
@ -864,6 +948,7 @@ private:
}
const ExternalDictionaries & dictionaries;
UInt32 decimal_scale;
};
struct NameDictGetUInt8 { static constexpr auto name = "dictGetUInt8"; };
@ -879,6 +964,9 @@ struct NameDictGetFloat64 { static constexpr auto name = "dictGetFloat64"; };
struct NameDictGetDate { static constexpr auto name = "dictGetDate"; };
struct NameDictGetDateTime { static constexpr auto name = "dictGetDateTime"; };
struct NameDictGetUUID { static constexpr auto name = "dictGetUUID"; };
struct NameDictGetDecimal32 { static constexpr auto name = "dictGetDecimal32"; };
struct NameDictGetDecimal64 { static constexpr auto name = "dictGetDecimal64"; };
struct NameDictGetDecimal128 { static constexpr auto name = "dictGetDecimal128"; };
using FunctionDictGetUInt8 = FunctionDictGet<DataTypeUInt8, NameDictGetUInt8>;
using FunctionDictGetUInt16 = FunctionDictGet<DataTypeUInt16, NameDictGetUInt16>;
@ -893,22 +981,29 @@ using FunctionDictGetFloat64 = FunctionDictGet<DataTypeFloat64, NameDictGetFloat
using FunctionDictGetDate = FunctionDictGet<DataTypeDate, NameDictGetDate>;
using FunctionDictGetDateTime = FunctionDictGet<DataTypeDateTime, NameDictGetDateTime>;
using FunctionDictGetUUID = FunctionDictGet<DataTypeUUID, NameDictGetUUID>;
using FunctionDictGetDecimal32 = FunctionDictGet<DataTypeDecimal<Decimal32>, NameDictGetDecimal32>;
using FunctionDictGetDecimal64 = FunctionDictGet<DataTypeDecimal<Decimal64>, NameDictGetDecimal64>;
using FunctionDictGetDecimal128 = FunctionDictGet<DataTypeDecimal<Decimal128>, NameDictGetDecimal128>;
template <typename DataType, typename Name>
class FunctionDictGetOrDefault final : public IFunction
{
using Type = typename DataType::FieldType;
using ColVec = std::conditional_t<IsDecimalNumber<Type>, ColumnDecimal<Type>, ColumnVector<Type>>;
public:
static constexpr auto name = Name::name;
static FunctionPtr create(const Context & context)
static FunctionPtr create(const Context & context, UInt32 dec_scale = 0)
{
return std::make_shared<FunctionDictGetOrDefault>(context.getExternalDictionaries());
return std::make_shared<FunctionDictGetOrDefault>(context.getExternalDictionaries(), dec_scale);
}
FunctionDictGetOrDefault(const ExternalDictionaries & dictionaries) : dictionaries(dictionaries) {}
FunctionDictGetOrDefault(const ExternalDictionaries & dictionaries, UInt32 dec_scale = 0)
: dictionaries(dictionaries)
, decimal_scale(dec_scale)
{}
String getName() const override { return name; }
@ -935,9 +1030,12 @@ private:
if (!checkAndGetDataType<DataType>(arguments[3].get()))
throw Exception{"Illegal type " + arguments[3]->getName() + " of fourth argument of function " + getName()
+ ", must be " + String(DataType{}.getFamilyName()) + ".", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
+ ", must be " + TypeName<Type>::get() + ".", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
return std::make_shared<DataType>();
if constexpr (IsDataTypeDecimal<DataType>)
return std::make_shared<DataType>(DataType::maxPrecision(), decimal_scale);
else
return std::make_shared<DataType>();
}
bool isDeterministic() const override { return false; }
@ -999,20 +1097,28 @@ private:
{
const auto default_col_untyped = block.getByPosition(arguments[3]).column.get();
if (const auto default_col = checkAndGetColumn<ColumnVector<Type>>(default_col_untyped))
if (const auto default_col = checkAndGetColumn<ColVec>(default_col_untyped))
{
/// vector ids, vector defaults
auto out = ColumnVector<Type>::create(id_col->size());
typename ColVec::MutablePtr out;
if constexpr (IsDataTypeDecimal<DataType>)
out = ColVec::create(id_col->size(), decimal_scale);
else
out = ColVec::create(id_col->size());
const auto & ids = id_col->getData();
auto & data = out->getData();
const auto & defs = default_col->getData();
DictGetTraits<DataType>::getOrDefault(dictionary, attr_name, ids, defs, data);
block.getByPosition(result).column = std::move(out);
}
else if (const auto default_col_const = checkAndGetColumnConst<ColumnVector<Type>>(default_col_untyped))
else if (const auto default_col_const = checkAndGetColumnConst<ColVec>(default_col_untyped))
{
/// vector ids, const defaults
auto out = ColumnVector<Type>::create(id_col->size());
typename ColVec::MutablePtr out;
if constexpr (IsDataTypeDecimal<DataType>)
out = ColVec::create(id_col->size(), decimal_scale);
else
out = ColVec::create(id_col->size());
const auto & ids = id_col->getData();
auto & data = out->getData();
const auto def = default_col_const->template getValue<Type>();
@ -1020,7 +1126,7 @@ private:
block.getByPosition(result).column = std::move(out);
}
else
throw Exception{"Fourth argument of function " + getName() + " must be " + String(DataType{}.getFamilyName()), ErrorCodes::ILLEGAL_COLUMN};
throw Exception{"Fourth argument of function " + getName() + " must be " + TypeName<Type>::get(), ErrorCodes::ILLEGAL_COLUMN};
}
template <typename DictionaryType>
@ -1030,7 +1136,7 @@ private:
{
const auto default_col_untyped = block.getByPosition(arguments[3]).column.get();
if (const auto default_col = checkAndGetColumn<ColumnVector<Type>>(default_col_untyped))
if (const auto default_col = checkAndGetColumn<ColVec>(default_col_untyped))
{
/// const ids, vector defaults
const PaddedPODArray<UInt64> ids(1, id_col->getValue<UInt64>());
@ -1038,24 +1144,48 @@ private:
dictionary->has(ids, flags);
if (flags.front())
{
PaddedPODArray<Type> data(1);
DictGetTraits<DataType>::getOrDefault(dictionary, attr_name, ids, Type(), data);
block.getByPosition(result).column = DataTypeNumber<Type>().createColumnConst(id_col->size(), toField(data.front()));
if constexpr (IsDataTypeDecimal<DataType>)
{
DecimalPaddedPODArray<Type> data(1, decimal_scale);
DictGetTraits<DataType>::getOrDefault(dictionary, attr_name, ids, Type(), data);
block.getByPosition(result).column =
DataType(DataType::maxPrecision(), decimal_scale).createColumnConst(
id_col->size(), toField(data.front(), decimal_scale));
}
else
{
PaddedPODArray<Type> data(1);
DictGetTraits<DataType>::getOrDefault(dictionary, attr_name, ids, Type(), data);
block.getByPosition(result).column = DataType().createColumnConst(id_col->size(), toField(data.front()));
}
}
else
block.getByPosition(result).column = block.getByPosition(arguments[3]).column; // reuse the default column
}
else if (const auto default_col_const = checkAndGetColumnConst<ColumnVector<Type>>(default_col_untyped))
else if (const auto default_col_const = checkAndGetColumnConst<ColVec>(default_col_untyped))
{
/// const ids, const defaults
const PaddedPODArray<UInt64> ids(1, id_col->getValue<UInt64>());
PaddedPODArray<Type> data(1);
const auto & def = default_col_const->template getValue<Type>();
DictGetTraits<DataType>::getOrDefault(dictionary, attr_name, ids, def, data);
block.getByPosition(result).column = DataTypeNumber<Type>().createColumnConst(id_col->size(), toField(data.front()));
if constexpr (IsDataTypeDecimal<DataType>)
{
DecimalPaddedPODArray<Type> data(1, decimal_scale);
const auto & def = default_col_const->template getValue<Type>();
DictGetTraits<DataType>::getOrDefault(dictionary, attr_name, ids, def, data);
block.getByPosition(result).column =
DataType(DataType::maxPrecision(), decimal_scale).createColumnConst(
id_col->size(), toField(data.front(), decimal_scale));
}
else
{
PaddedPODArray<Type> data(1);
const auto & def = default_col_const->template getValue<Type>();
DictGetTraits<DataType>::getOrDefault(dictionary, attr_name, ids, def, data);
block.getByPosition(result).column = DataType().createColumnConst(id_col->size(), toField(data.front()));
}
}
else
throw Exception{"Fourth argument of function " + getName() + " must be " + String(DataType{}.getFamilyName()), ErrorCodes::ILLEGAL_COLUMN};
throw Exception{"Fourth argument of function " + getName() + " must be " + TypeName<Type>::get(), ErrorCodes::ILLEGAL_COLUMN};
}
template <typename DictionaryType>
@ -1082,31 +1212,36 @@ private:
/// @todo detect when all key columns are constant
const auto rows = key_col->size();
auto out = ColumnVector<Type>::create(rows);
typename ColVec::MutablePtr out;
if constexpr (IsDataTypeDecimal<DataType>)
out = ColVec::create(rows, decimal_scale);
else
out = ColVec::create(rows);
auto & data = out->getData();
const auto default_col_untyped = block.getByPosition(arguments[3]).column.get();
if (const auto default_col = checkAndGetColumn<ColumnVector<Type>>(default_col_untyped))
if (const auto default_col = checkAndGetColumn<ColVec>(default_col_untyped))
{
/// const defaults
const auto & defs = default_col->getData();
DictGetTraits<DataType>::getOrDefault(dict, attr_name, key_columns, key_types, defs, data);
}
else if (const auto default_col_const = checkAndGetColumnConst<ColumnVector<Type>>(default_col_untyped))
else if (const auto default_col_const = checkAndGetColumnConst<ColVec>(default_col_untyped))
{
const auto def = default_col_const->template getValue<Type>();
DictGetTraits<DataType>::getOrDefault(dict, attr_name, key_columns, key_types, def, data);
}
else
throw Exception{"Fourth argument of function " + getName() + " must be " + String(DataType{}.getFamilyName()), ErrorCodes::ILLEGAL_COLUMN};
throw Exception{"Fourth argument of function " + getName() + " must be " + TypeName<Type>::get(), ErrorCodes::ILLEGAL_COLUMN};
block.getByPosition(result).column = std::move(out);
return true;
}
const ExternalDictionaries & dictionaries;
UInt32 decimal_scale;
};
struct NameDictGetUInt8OrDefault { static constexpr auto name = "dictGetUInt8OrDefault"; };
@ -1122,6 +1257,9 @@ struct NameDictGetFloat64OrDefault { static constexpr auto name = "dictGetFloat6
struct NameDictGetDateOrDefault { static constexpr auto name = "dictGetDateOrDefault"; };
struct NameDictGetDateTimeOrDefault { static constexpr auto name = "dictGetDateTimeOrDefault"; };
struct NameDictGetUUIDOrDefault { static constexpr auto name = "dictGetUUIDOrDefault"; };
struct NameDictGetDecimal32OrDefault { static constexpr auto name = "dictGetDecimal32OrDefault"; };
struct NameDictGetDecimal64OrDefault { static constexpr auto name = "dictGetDecimal64OrDefault"; };
struct NameDictGetDecimal128OrDefault { static constexpr auto name = "dictGetDecimal128OrDefault"; };
using FunctionDictGetUInt8OrDefault = FunctionDictGetOrDefault<DataTypeUInt8, NameDictGetUInt8OrDefault>;
using FunctionDictGetUInt16OrDefault = FunctionDictGetOrDefault<DataTypeUInt16, NameDictGetUInt16OrDefault>;
@ -1136,21 +1274,10 @@ using FunctionDictGetFloat64OrDefault = FunctionDictGetOrDefault<DataTypeFloat64
using FunctionDictGetDateOrDefault = FunctionDictGetOrDefault<DataTypeDate, NameDictGetDateOrDefault>;
using FunctionDictGetDateTimeOrDefault = FunctionDictGetOrDefault<DataTypeDateTime, NameDictGetDateTimeOrDefault>;
using FunctionDictGetUUIDOrDefault = FunctionDictGetOrDefault<DataTypeUUID, NameDictGetUUIDOrDefault>;
using FunctionDictGetDecimal32OrDefault = FunctionDictGetOrDefault<DataTypeDecimal<Decimal32>, NameDictGetDecimal32OrDefault>;
using FunctionDictGetDecimal64OrDefault = FunctionDictGetOrDefault<DataTypeDecimal<Decimal64>, NameDictGetDecimal64OrDefault>;
using FunctionDictGetDecimal128OrDefault = FunctionDictGetOrDefault<DataTypeDecimal<Decimal128>, NameDictGetDecimal128OrDefault>;
#define FOR_DICT_TYPES(M) \
M(UInt8) \
M(UInt16) \
M(UInt32) \
M(UInt64) \
M(Int8) \
M(Int16) \
M(Int32) \
M(Int64) \
M(Float32) \
M(Float64) \
M(Date) \
M(DateTime) \
M(UUID)
/// This variant of function derives the result type automatically.
class FunctionDictGetNoType final : public IFunction
@ -1225,15 +1352,63 @@ private:
if (attribute.name == attr_name)
{
WhichDataType dt = attribute.type;
if (dt.idx == TypeIndex::String)
impl = FunctionDictGetString::create(context);
#define DISPATCH(TYPE) \
else if (dt.idx == TypeIndex::TYPE) \
impl = FunctionDictGet<DataType##TYPE, NameDictGet##TYPE>::create(context);
FOR_DICT_TYPES(DISPATCH)
#undef DISPATCH
else
throw Exception("Unknown dictGet type", ErrorCodes::UNKNOWN_TYPE);
switch (dt.idx)
{
case TypeIndex::String:
case TypeIndex::FixedString:
impl = FunctionDictGetString::create(context);
break;
case TypeIndex::UInt8:
impl = FunctionDictGetUInt8::create(context);
break;
case TypeIndex::UInt16:
impl = FunctionDictGetUInt16::create(context);
break;
case TypeIndex::UInt32:
impl = FunctionDictGetUInt32::create(context);
break;
case TypeIndex::UInt64:
impl = FunctionDictGetUInt64::create(context);
break;
case TypeIndex::Int8:
impl = FunctionDictGetInt8::create(context);
break;
case TypeIndex::Int16:
impl = FunctionDictGetInt16::create(context);
break;
case TypeIndex::Int32:
impl = FunctionDictGetInt32::create(context);
break;
case TypeIndex::Int64:
impl = FunctionDictGetInt64::create(context);
break;
case TypeIndex::Float32:
impl = FunctionDictGetFloat32::create(context);
break;
case TypeIndex::Float64:
impl = FunctionDictGetFloat64::create(context);
break;
case TypeIndex::Date:
impl = FunctionDictGetDate::create(context);
break;
case TypeIndex::DateTime:
impl = FunctionDictGetDateTime::create(context);
break;
case TypeIndex::UUID:
impl = FunctionDictGetUUID::create(context);
break;
case TypeIndex::Decimal32:
impl = FunctionDictGetDecimal32::create(context, getDecimalScale(*attribute.type));
break;
case TypeIndex::Decimal64:
impl = FunctionDictGetDecimal64::create(context, getDecimalScale(*attribute.type));
break;
case TypeIndex::Decimal128:
impl = FunctionDictGetDecimal128::create(context, getDecimalScale(*attribute.type));
break;
default:
throw Exception("Unknown dictGet type", ErrorCodes::UNKNOWN_TYPE);
}
return attribute.type;
}
}
@ -1312,26 +1487,70 @@ private:
const DictionaryAttribute & attribute = structure.attributes[idx];
if (attribute.name == attr_name)
{
auto arg_type = arguments[3].type;
WhichDataType dt = attribute.type;
if (dt.idx == TypeIndex::String)
if ((arg_type->getTypeId() != dt.idx) || (dt.isStringOrFixedString() && !isString(arg_type)))
throw Exception{"Illegal type " + arg_type->getName() + " of fourth argument of function " + getName() +
", must be " + getTypeName(dt.idx) + ".", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
switch (dt.idx)
{
if (!isString(arguments[3].type))
throw Exception{"Illegal type " + arguments[3].type->getName() + " of fourth argument of function " + getName() +
", must be String.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
impl = FunctionDictGetStringOrDefault::create(context);
case TypeIndex::String:
impl = FunctionDictGetStringOrDefault::create(context);
break;
case TypeIndex::UInt8:
impl = FunctionDictGetUInt8OrDefault::create(context);
break;
case TypeIndex::UInt16:
impl = FunctionDictGetUInt16OrDefault::create(context);
break;
case TypeIndex::UInt32:
impl = FunctionDictGetUInt32OrDefault::create(context);
break;
case TypeIndex::UInt64:
impl = FunctionDictGetUInt64OrDefault::create(context);
break;
case TypeIndex::Int8:
impl = FunctionDictGetInt8OrDefault::create(context);
break;
case TypeIndex::Int16:
impl = FunctionDictGetInt16OrDefault::create(context);
break;
case TypeIndex::Int32:
impl = FunctionDictGetInt32OrDefault::create(context);
break;
case TypeIndex::Int64:
impl = FunctionDictGetInt64OrDefault::create(context);
break;
case TypeIndex::Float32:
impl = FunctionDictGetFloat32OrDefault::create(context);
break;
case TypeIndex::Float64:
impl = FunctionDictGetFloat64OrDefault::create(context);
break;
case TypeIndex::Date:
impl = FunctionDictGetDateOrDefault::create(context);
break;
case TypeIndex::DateTime:
impl = FunctionDictGetDateTimeOrDefault::create(context);
break;
case TypeIndex::UUID:
impl = FunctionDictGetUUIDOrDefault::create(context);
break;
case TypeIndex::Decimal32:
impl = FunctionDictGetDecimal32OrDefault::create(context, getDecimalScale(*attribute.type));
break;
case TypeIndex::Decimal64:
impl = FunctionDictGetDecimal64OrDefault::create(context, getDecimalScale(*attribute.type));
break;
case TypeIndex::Decimal128:
impl = FunctionDictGetDecimal128OrDefault::create(context, getDecimalScale(*attribute.type));
break;
default:
throw Exception("Unknown dictGetOrDefault type", ErrorCodes::UNKNOWN_TYPE);
}
#define DISPATCH(TYPE) \
else if (dt.idx == TypeIndex::TYPE) \
{ \
if (!checkAndGetDataType<DataType##TYPE>(arguments[3].type.get())) \
throw Exception{"Illegal type " + arguments[3].type->getName() + " of fourth argument of function " + getName() \
+ ", must be " + String(DataType##TYPE{}.getFamilyName()) + ".", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; \
impl = FunctionDictGetOrDefault<DataType##TYPE, NameDictGet##TYPE ## OrDefault>::create(context); \
}
FOR_DICT_TYPES(DISPATCH)
#undef DISPATCH
else
throw Exception("Unknown dictGetOrDefault type", ErrorCodes::UNKNOWN_TYPE);
return attribute.type;
}
}

View File

@ -111,7 +111,7 @@ public:
const auto type_x = arguments[0];
if (!isNumber(type_x))
if (!isNativeNumber(type_x))
throw Exception{"Unsupported type " + type_x->getName() + " of first argument of function " + getName() + " must be a numeric type",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};

View File

@ -143,7 +143,7 @@ public:
{
const IDataType & type = *arguments[0];
if (!isNumber(type))
if (!isNativeNumber(type))
throw Exception("Cannot format " + type.getName() + " as size in bytes", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return std::make_shared<DataTypeString>();

View File

@ -138,7 +138,7 @@ public:
for (auto j : ext::range(0, elements.size()))
{
if (!isNumber(elements[j]))
if (!isNativeNumber(elements[j]))
{
throw Exception(getMsgPrefix(i) + " must contains numeric tuple at position " + toString(j + 1),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

View File

@ -120,11 +120,22 @@ private:
/// prepare() does Impl-specific preparation before handling each row.
impl.prepare(Name::name, block, arguments, result_pos);
bool json_parsed_ok = false;
if (col_json_const)
{
StringRef json{reinterpret_cast<const char *>(&chars[0]), offsets[0] - 1};
json_parsed_ok = parser.parse(json);
}
for (const auto i : ext::range(0, input_rows_count))
{
StringRef json{reinterpret_cast<const char *>(&chars[offsets[i - 1]]), offsets[i] - offsets[i - 1] - 1};
bool ok = parser.parse(json);
if (!col_json_const)
{
StringRef json{reinterpret_cast<const char *>(&chars[offsets[i - 1]]), offsets[i] - offsets[i - 1] - 1};
json_parsed_ok = parser.parse(json);
}
bool ok = json_parsed_ok;
if (ok)
{
auto it = parser.getRoot();

View File

@ -309,8 +309,8 @@ public:
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
for (size_t i = 0; i < arguments.size(); ++i)
if (!(isNumber(arguments[i])
|| (Impl::specialImplementationForNulls() && (arguments[i]->onlyNull() || isNumber(removeNullable(arguments[i]))))))
if (!(isNativeNumber(arguments[i])
|| (Impl::specialImplementationForNulls() && (arguments[i]->onlyNull() || isNativeNumber(removeNullable(arguments[i]))))))
throw Exception("Illegal type ("
+ arguments[i]->getName()
+ ") of " + toString(i + 1) + " argument of function " + getName(),
@ -488,7 +488,7 @@ public:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!isNumber(arguments[0]))
if (!isNativeNumber(arguments[0]))
throw Exception("Illegal type ("
+ arguments[0]->getName()
+ ") of argument of function " + getName(),

View File

@ -500,7 +500,7 @@ public:
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
for (const auto & type : arguments)
if (!isNumber(type) && !isDecimal(type))
if (!isNumber(type))
throw Exception("Illegal type " + arguments[0]->getName() + " of argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@ -588,7 +588,7 @@ public:
{
const DataTypePtr & type_x = arguments[0];
if (!(isNumber(type_x) || isDecimal(type_x)))
if (!isNumber(type_x))
throw Exception{"Unsupported type " + type_x->getName()
+ " of first argument of function " + getName()
+ ", must be numeric type.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
@ -601,7 +601,7 @@ public:
const auto type_arr_nested = type_arr->getNestedType();
if (!(isNumber(type_arr_nested) || isDecimal(type_arr_nested)))
if (!isNumber(type_arr_nested))
{
throw Exception{"Elements of array of second argument of function " + getName()
+ " must be numeric type.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};

View File

@ -32,7 +32,7 @@ namespace DB
* calculation. If the right string size is big (more than 2**15 bytes),
* the strings are not similar at all and we return 1.
*/
template <size_t N, class CodePoint, bool UTF8, bool CaseInsensitive>
template <size_t N, class CodePoint, bool UTF8, bool CaseInsensitive, bool Symmetric>
struct NgramDistanceImpl
{
using ResultType = Float32;
@ -138,6 +138,7 @@ struct NgramDistanceImpl
}
/// This is not a really true case insensitive utf8. We zero the 5-th bit of every byte.
/// And first bit of first byte if there are two bytes.
/// For ASCII it works https://catonmat.net/ascii-case-conversion-trick. For most cyrrilic letters also does.
/// For others, we don't care now. Lowering UTF is not a cheap operation.
if constexpr (CaseInsensitive)
@ -151,6 +152,7 @@ struct NgramDistanceImpl
res &= ~(1u << (5 + 2 * CHAR_BIT));
[[fallthrough]];
case 2:
res &= ~(1u);
res &= ~(1u << (5 + CHAR_BIT));
[[fallthrough]];
default:
@ -222,9 +224,10 @@ struct NgramDistanceImpl
for (; iter + N <= found; ++iter)
{
UInt16 hash = hash_functor(cp + iter);
/// For symmetric version we should add when we can't subtract to get symmetric difference.
if (static_cast<Int16>(ngram_stats[hash]) > 0)
--distance;
else
else if constexpr (Symmetric)
++distance;
if constexpr (ReuseStats)
ngram_storage[ngram_cnt] = hash;
@ -267,7 +270,8 @@ struct NgramDistanceImpl
if (data_size <= max_string_size)
{
size_t first_size = dispatchSearcher(calculateHaystackStatsAndMetric<false>, data.data(), data_size, common_stats, distance, nullptr);
res = distance * 1.f / std::max(first_size + second_size, size_t(1));
/// For !Symmetric version we should not use first_size.
res = distance * 1.f / std::max(Symmetric * first_size + second_size, size_t(1));
}
else
{
@ -326,7 +330,10 @@ struct NgramDistanceImpl
--common_stats[needle_ngram_storage[j]];
/// For now, common stats is a zero array.
res[i] = distance * 1.f / std::max(haystack_stats_size + needle_stats_size, size_t(1));
/// For !Symmetric version we should not use haystack_stats_size.
res[i] = distance * 1.f / std::max(Symmetric * haystack_stats_size + needle_stats_size, size_t(1));
}
else
{
@ -340,6 +347,71 @@ struct NgramDistanceImpl
}
}
static void constant_vector(
std::string haystack,
const ColumnString::Chars & needle_data,
const ColumnString::Offsets & needle_offsets,
PaddedPODArray<Float32> & res)
{
/// For symmetric version it is better to use vector_constant
if constexpr (Symmetric)
{
vector_constant(needle_data, needle_offsets, std::move(haystack), res);
}
else
{
const size_t haystack_size = haystack.size();
haystack.resize(haystack_size + default_padding);
/// For logic explanation see vector_vector function.
const size_t needle_offsets_size = needle_offsets.size();
size_t prev_offset = 0;
NgramStats common_stats = {};
std::unique_ptr<UInt16[]> needle_ngram_storage(new UInt16[max_string_size]);
std::unique_ptr<UInt16[]> haystack_ngram_storage(new UInt16[max_string_size]);
for (size_t i = 0; i < needle_offsets_size; ++i)
{
const char * needle = reinterpret_cast<const char *>(&needle_data[prev_offset]);
const size_t needle_size = needle_offsets[i] - prev_offset - 1;
if (needle_size <= max_string_size && haystack_size <= max_string_size)
{
const size_t needle_stats_size = dispatchSearcher(
calculateNeedleStats<true>,
needle,
needle_size,
common_stats,
needle_ngram_storage.get());
size_t distance = needle_stats_size;
dispatchSearcher(
calculateHaystackStatsAndMetric<true>,
haystack.data(),
haystack_size,
common_stats,
distance,
haystack_ngram_storage.get());
for (size_t j = 0; j < needle_stats_size; ++j)
--common_stats[needle_ngram_storage[j]];
res[i] = distance * 1.f / std::max(needle_stats_size, size_t(1));
}
else
{
res[i] = 1.f;
}
prev_offset = needle_offsets[i];
}
}
}
static void vector_constant(
const ColumnString::Chars & data,
const ColumnString::Offsets & offsets,
@ -373,7 +445,8 @@ struct NgramDistanceImpl
haystack_size, common_stats,
distance,
ngram_storage.get());
res[i] = distance * 1.f / std::max(haystack_stats_size + needle_stats_size, size_t(1));
/// For !Symmetric version we should not use haystack_stats_size.
res[i] = distance * 1.f / std::max(Symmetric * haystack_stats_size + needle_stats_size, size_t(1));
}
else
{
@ -391,7 +464,6 @@ struct NameNgramDistance
{
static constexpr auto name = "ngramDistance";
};
struct NameNgramDistanceCaseInsensitive
{
static constexpr auto name = "ngramDistanceCaseInsensitive";
@ -407,10 +479,34 @@ struct NameNgramDistanceUTF8CaseInsensitive
static constexpr auto name = "ngramDistanceCaseInsensitiveUTF8";
};
using FunctionNgramDistance = FunctionsStringSimilarity<NgramDistanceImpl<4, UInt8, false, false>, NameNgramDistance>;
using FunctionNgramDistanceCaseInsensitive = FunctionsStringSimilarity<NgramDistanceImpl<4, UInt8, false, true>, NameNgramDistanceCaseInsensitive>;
using FunctionNgramDistanceUTF8 = FunctionsStringSimilarity<NgramDistanceImpl<3, UInt32, true, false>, NameNgramDistanceUTF8>;
using FunctionNgramDistanceCaseInsensitiveUTF8 = FunctionsStringSimilarity<NgramDistanceImpl<3, UInt32, true, true>, NameNgramDistanceUTF8CaseInsensitive>;
struct NameNgramSearch
{
static constexpr auto name = "ngramSearch";
};
struct NameNgramSearchCaseInsensitive
{
static constexpr auto name = "ngramSearchCaseInsensitive";
};
struct NameNgramSearchUTF8
{
static constexpr auto name = "ngramSearchUTF8";
};
struct NameNgramSearchUTF8CaseInsensitive
{
static constexpr auto name = "ngramSearchCaseInsensitiveUTF8";
};
using FunctionNgramDistance = FunctionsStringSimilarity<NgramDistanceImpl<4, UInt8, false, false, true>, NameNgramDistance>;
using FunctionNgramDistanceCaseInsensitive = FunctionsStringSimilarity<NgramDistanceImpl<4, UInt8, false, true, true>, NameNgramDistanceCaseInsensitive>;
using FunctionNgramDistanceUTF8 = FunctionsStringSimilarity<NgramDistanceImpl<3, UInt32, true, false, true>, NameNgramDistanceUTF8>;
using FunctionNgramDistanceCaseInsensitiveUTF8 = FunctionsStringSimilarity<NgramDistanceImpl<3, UInt32, true, true, true>, NameNgramDistanceUTF8CaseInsensitive>;
using FunctionNgramSearch = FunctionsStringSimilarity<NgramDistanceImpl<4, UInt8, false, false, false>, NameNgramSearch>;
using FunctionNgramSearchCaseInsensitive = FunctionsStringSimilarity<NgramDistanceImpl<4, UInt8, false, true, false>, NameNgramSearchCaseInsensitive>;
using FunctionNgramSearchUTF8 = FunctionsStringSimilarity<NgramDistanceImpl<3, UInt32, true, false, false>, NameNgramSearchUTF8>;
using FunctionNgramSearchCaseInsensitiveUTF8 = FunctionsStringSimilarity<NgramDistanceImpl<3, UInt32, true, true, false>, NameNgramSearchUTF8CaseInsensitive>;
void registerFunctionsStringSimilarity(FunctionFactory & factory)
{
@ -418,6 +514,11 @@ void registerFunctionsStringSimilarity(FunctionFactory & factory)
factory.registerFunction<FunctionNgramDistanceCaseInsensitive>();
factory.registerFunction<FunctionNgramDistanceUTF8>();
factory.registerFunction<FunctionNgramDistanceCaseInsensitiveUTF8>();
factory.registerFunction<FunctionNgramSearch>();
factory.registerFunction<FunctionNgramSearchCaseInsensitive>();
factory.registerFunction<FunctionNgramSearchUTF8>();
factory.registerFunction<FunctionNgramSearchCaseInsensitiveUTF8>();
}
}

View File

@ -110,15 +110,15 @@ public:
}
else if (col_haystack_const && col_needle_vector)
{
const String & needle = col_haystack_const->getValue<String>();
if (needle.size() > Impl::max_string_size)
const String & haystack = col_haystack_const->getValue<String>();
if (haystack.size() > Impl::max_string_size)
{
throw Exception(
"String size of needle is too big for function " + getName() + ". Should be at most "
"String size of haystack is too big for function " + getName() + ". Should be at most "
+ std::to_string(Impl::max_string_size),
ErrorCodes::TOO_LARGE_STRING_SIZE);
}
Impl::vector_constant(col_needle_vector->getChars(), col_needle_vector->getOffsets(), needle, vec_res);
Impl::constant_vector(haystack, col_needle_vector->getChars(), col_needle_vector->getOffsets(), vec_res);
}
else
{

View File

@ -595,7 +595,7 @@ inline bool allowArrayIndex(const DataTypePtr & type0, const DataTypePtr & type1
DataTypePtr data_type0 = removeNullable(type0);
DataTypePtr data_type1 = removeNullable(type1);
return ((isNumber(data_type0) || isEnum(data_type0)) && isNumber(data_type1))
return ((isNativeNumber(data_type0) || isEnum(data_type0)) && isNativeNumber(data_type1))
|| data_type0->equals(*data_type1);
}

View File

@ -183,7 +183,7 @@ Columns FunctionArrayIntersect::castColumns(
auto & type_nested = type_array->getNestedType();
auto type_not_nullable_nested = removeNullable(type_nested);
const bool is_numeric_or_string = isNumber(type_not_nullable_nested)
const bool is_numeric_or_string = isNativeNumber(type_not_nullable_nested)
|| isDateOrDateTime(type_not_nullable_nested)
|| isStringOrFixedString(type_not_nullable_nested);

View File

@ -37,7 +37,7 @@ public:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!isNumber(arguments[0]))
if (!isNativeNumber(arguments[0]))
throw Exception("Illegal type " + arguments[0]->getName() +
" of argument of function " + getName() +
", expected Integer", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

View File

@ -55,8 +55,8 @@ public:
+ ".",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
if (!isNumber(arguments[0]) || !isNumber(arguments[1]) || !isNumber(arguments[2])
|| (arguments.size() == 4 && !isNumber(arguments[3])))
if (!isNativeNumber(arguments[0]) || !isNativeNumber(arguments[1]) || !isNativeNumber(arguments[2])
|| (arguments.size() == 4 && !isNativeNumber(arguments[3])))
throw Exception("All arguments for function " + getName() + " must be numeric.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return std::make_shared<DataTypeString>();

View File

@ -2,6 +2,7 @@
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <DataTypes/DataTypeAggregateFunction.h>
#include <DataTypes/DataTypesNumber.h>
#include <Columns/ColumnAggregateFunction.h>
#include <Common/typeid_cast.h>
@ -60,7 +61,7 @@ public:
throw Exception("Argument for function " + getName() + " must have type AggregateFunction - state of aggregate function.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return type->getReturnType();
return type->getReturnTypeToPredict();
}
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override

Some files were not shown because too many files have changed in this diff Show More