mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-12-11 17:02:25 +00:00
Merge pull request #40897 from ClickHouse/catboost-bridge-resurrected
Move CatBoost evaluation into clickhouse-library-bridge
This commit is contained in:
commit
b32b02d844
4
.gitignore
vendored
4
.gitignore
vendored
@ -58,6 +58,10 @@ cmake_install.cmake
|
||||
CTestTestfile.cmake
|
||||
*.a
|
||||
*.o
|
||||
*.so
|
||||
*.dll
|
||||
*.lib
|
||||
*.dylib
|
||||
cmake-build-*
|
||||
|
||||
# Python cache
|
||||
|
@ -1823,6 +1823,36 @@ Result:
|
||||
Evaluate external model.
|
||||
Accepts a model name and model arguments. Returns Float64.
|
||||
|
||||
## catboostEvaluate(path_to_model, feature_1, feature_2, …, feature_n)
|
||||
|
||||
Evaluate external catboost model. [CatBoost](https://catboost.ai) is an open-source gradient boosting library developed by Yandex for machine learing.
|
||||
Accepts a path to a catboost model and model arguments (features). Returns Float64.
|
||||
|
||||
``` sql
|
||||
SELECT feat1, ..., feat_n, catboostEvaluate('/path/to/model.bin', feat_1, ..., feat_n) AS prediction
|
||||
FROM data_table
|
||||
```
|
||||
|
||||
**Prerequisites**
|
||||
|
||||
1. Build the catboost evaluation library
|
||||
|
||||
Before evaluating catboost models, the `libcatboostmodel.<so|dylib>` library must be made available. See [CatBoost documentation](https://catboost.ai/docs/concepts/c-plus-plus-api_dynamic-c-pluplus-wrapper.html) how to compile it.
|
||||
|
||||
Next, specify the path to `libcatboostmodel.<so|dylib>` in the clickhouse configuration:
|
||||
|
||||
``` xml
|
||||
<clickhouse>
|
||||
...
|
||||
<catboost_lib_path>/path/to/libcatboostmodel.so</catboost_lib_path>
|
||||
...
|
||||
</clickhouse>
|
||||
```
|
||||
|
||||
2. Train a catboost model using libcatboost
|
||||
|
||||
See [Training and applying models](https://catboost.ai/docs/features/training.html#training) for how to train catboost models from a training data set.
|
||||
|
||||
## throwIf(x\[, message\[, error_code\]\])
|
||||
|
||||
Throw an exception if the argument is non zero.
|
||||
|
@ -30,7 +30,12 @@ SELECT name, status FROM system.dictionaries;
|
||||
|
||||
## RELOAD MODELS
|
||||
|
||||
Reloads all [CatBoost](../../guides/developer/apply-catboost-model.md) models if the configuration was updated without restarting the server.
|
||||
:::note
|
||||
This statement and `SYSTEM RELOAD MODEL` merely unload catboost models from the clickhouse-library-bridge. The function `catboostEvaluate()`
|
||||
loads a model upon first access if it is not loaded yet.
|
||||
:::
|
||||
|
||||
Unloads all CatBoost models.
|
||||
|
||||
**Syntax**
|
||||
|
||||
@ -40,12 +45,12 @@ SYSTEM RELOAD MODELS [ON CLUSTER cluster_name]
|
||||
|
||||
## RELOAD MODEL
|
||||
|
||||
Completely reloads a CatBoost model `model_name` if the configuration was updated without restarting the server.
|
||||
Unloads a CatBoost model at `model_path`.
|
||||
|
||||
**Syntax**
|
||||
|
||||
```sql
|
||||
SYSTEM RELOAD MODEL [ON CLUSTER cluster_name] <model_name>
|
||||
SYSTEM RELOAD MODEL [ON CLUSTER cluster_name] <model_path>
|
||||
```
|
||||
|
||||
## RELOAD FUNCTIONS
|
||||
|
@ -155,7 +155,6 @@ getting_started/index.md getting-started/index.md
|
||||
getting_started/install.md getting-started/install.md
|
||||
getting_started/playground.md getting-started/playground.md
|
||||
getting_started/tutorial.md getting-started/tutorial.md
|
||||
guides/apply_catboost_model.md guides/apply-catboost-model.md
|
||||
images/column_oriented.gif images/column-oriented.gif
|
||||
images/row_oriented.gif images/row-oriented.gif
|
||||
interfaces/http_interface.md interfaces/http.md
|
||||
|
@ -1,241 +0,0 @@
|
||||
---
|
||||
slug: /ru/guides/apply-catboost-model
|
||||
sidebar_position: 41
|
||||
sidebar_label: "Применение модели CatBoost в ClickHouse"
|
||||
---
|
||||
|
||||
# Применение модели CatBoost в ClickHouse {#applying-catboost-model-in-clickhouse}
|
||||
|
||||
[CatBoost](https://catboost.ai) — открытая программная библиотека разработанная компанией [Яндекс](https://yandex.ru/company/) для машинного обучения, которая использует схему градиентного бустинга.
|
||||
|
||||
С помощью этой инструкции вы научитесь применять предобученные модели в ClickHouse: в результате вы запустите вывод модели из SQL.
|
||||
|
||||
Чтобы применить модель CatBoost в ClickHouse:
|
||||
|
||||
1. [Создайте таблицу](#create-table).
|
||||
2. [Вставьте данные в таблицу](#insert-data-to-table).
|
||||
3. [Интегрируйте CatBoost в ClickHouse](#integrate-catboost-into-clickhouse) (Опциональный шаг).
|
||||
4. [Запустите вывод модели из SQL](#run-model-inference).
|
||||
|
||||
Подробнее об обучении моделей в CatBoost, см. [Обучение и применение моделей](https://catboost.ai/docs/features/training.html#training).
|
||||
|
||||
Вы можете перегрузить модели CatBoost, если их конфигурация была обновлена, без перезагрузки сервера. Для этого используйте системные запросы [RELOAD MODEL](../sql-reference/statements/system.md#query_language-system-reload-model) и [RELOAD MODELS](../sql-reference/statements/system.md#query_language-system-reload-models).
|
||||
|
||||
## Перед началом работы {#prerequisites}
|
||||
|
||||
Если у вас еще нет [Docker](https://docs.docker.com/install/), установите его.
|
||||
|
||||
:::note "Примечание"
|
||||
[Docker](https://www.docker.com) – это программная платформа для создания контейнеров, которые изолируют установку CatBoost и ClickHouse от остальной части системы.
|
||||
:::
|
||||
Перед применением модели CatBoost:
|
||||
|
||||
**1.** Скачайте [Docker-образ](https://hub.docker.com/r/yandex/tutorial-catboost-clickhouse) из реестра:
|
||||
|
||||
``` bash
|
||||
$ docker pull yandex/tutorial-catboost-clickhouse
|
||||
```
|
||||
|
||||
Данный Docker-образ содержит все необходимое для запуска CatBoost и ClickHouse: код, среду выполнения, библиотеки, переменные окружения и файлы конфигурации.
|
||||
|
||||
**2.** Проверьте, что Docker-образ успешно скачался:
|
||||
|
||||
``` bash
|
||||
$ docker image ls
|
||||
REPOSITORY TAG IMAGE ID CREATED SIZE
|
||||
yandex/tutorial-catboost-clickhouse latest 622e4d17945b 22 hours ago 1.37GB
|
||||
```
|
||||
|
||||
**3.** Запустите Docker-контейнер основанный на данном образе:
|
||||
|
||||
``` bash
|
||||
$ docker run -it -p 8888:8888 yandex/tutorial-catboost-clickhouse
|
||||
```
|
||||
|
||||
## 1. Создайте таблицу {#create-table}
|
||||
|
||||
Чтобы создать таблицу для обучающей выборки:
|
||||
|
||||
**1.** Запустите клиент ClickHouse:
|
||||
|
||||
``` bash
|
||||
$ clickhouse client
|
||||
```
|
||||
|
||||
:::note "Примечание"
|
||||
Сервер ClickHouse уже запущен внутри Docker-контейнера.
|
||||
:::
|
||||
**2.** Создайте таблицу в ClickHouse с помощью следующей команды:
|
||||
|
||||
``` sql
|
||||
:) CREATE TABLE amazon_train
|
||||
(
|
||||
date Date MATERIALIZED today(),
|
||||
ACTION UInt8,
|
||||
RESOURCE UInt32,
|
||||
MGR_ID UInt32,
|
||||
ROLE_ROLLUP_1 UInt32,
|
||||
ROLE_ROLLUP_2 UInt32,
|
||||
ROLE_DEPTNAME UInt32,
|
||||
ROLE_TITLE UInt32,
|
||||
ROLE_FAMILY_DESC UInt32,
|
||||
ROLE_FAMILY UInt32,
|
||||
ROLE_CODE UInt32
|
||||
)
|
||||
ENGINE = MergeTree ORDER BY date
|
||||
```
|
||||
|
||||
**3.** Выйдите из клиента ClickHouse:
|
||||
|
||||
``` sql
|
||||
:) exit
|
||||
```
|
||||
|
||||
## 2. Вставьте данные в таблицу {#insert-data-to-table}
|
||||
|
||||
Чтобы вставить данные:
|
||||
|
||||
**1.** Выполните следующую команду:
|
||||
|
||||
``` bash
|
||||
$ clickhouse client --host 127.0.0.1 --query 'INSERT INTO amazon_train FORMAT CSVWithNames' < ~/amazon/train.csv
|
||||
```
|
||||
|
||||
**2.** Запустите клиент ClickHouse:
|
||||
|
||||
``` bash
|
||||
$ clickhouse client
|
||||
```
|
||||
|
||||
**3.** Проверьте, что данные успешно загрузились:
|
||||
|
||||
``` sql
|
||||
:) SELECT count() FROM amazon_train
|
||||
|
||||
SELECT count()
|
||||
FROM amazon_train
|
||||
|
||||
+-count()-+
|
||||
| 65538 |
|
||||
+---------+
|
||||
```
|
||||
|
||||
## 3. Интегрируйте CatBoost в ClickHouse {#integrate-catboost-into-clickhouse}
|
||||
|
||||
:::note "Примечание"
|
||||
**Опциональный шаг.** Docker-образ содержит все необходимое для запуска CatBoost и ClickHouse.
|
||||
:::
|
||||
Чтобы интегрировать CatBoost в ClickHouse:
|
||||
|
||||
**1.** Создайте библиотеку для оценки модели.
|
||||
|
||||
Наиболее быстрый способ оценить модель CatBoost — это скомпилировать библиотеку `libcatboostmodel.<so|dll|dylib>`. Подробнее о том, как скомпилировать библиотеку, читайте в [документации CatBoost](https://catboost.ai/docs/concepts/c-plus-plus-api_dynamic-c-pluplus-wrapper.html).
|
||||
|
||||
**2.** Создайте в любом месте новую директорию с произвольным названием, например `data` и поместите в нее созданную библиотеку. Docker-образ уже содержит библиотеку `data/libcatboostmodel.so`.
|
||||
|
||||
**3.** Создайте в любом месте новую директорию для конфигурации модели с произвольным названием, например `models`.
|
||||
|
||||
**4.** Создайте файл конфигурации модели с произвольным названием, например `models/amazon_model.xml`.
|
||||
|
||||
**5.** Опишите конфигурацию модели:
|
||||
|
||||
``` xml
|
||||
<models>
|
||||
<model>
|
||||
<!-- Тип модели. В настоящий момент ClickHouse предоставляет только модель catboost. -->
|
||||
<type>catboost</type>
|
||||
<!-- Имя модели. -->
|
||||
<name>amazon</name>
|
||||
<!-- Путь к обученной модели. -->
|
||||
<path>/home/catboost/tutorial/catboost_model.bin</path>
|
||||
<!-- Интервал обновления. -->
|
||||
<lifetime>0</lifetime>
|
||||
</model>
|
||||
</models>
|
||||
```
|
||||
|
||||
**6.** Добавьте в конфигурацию ClickHouse путь к CatBoost и конфигурации модели:
|
||||
|
||||
``` xml
|
||||
<!-- Файл etc/clickhouse-server/config.d/models_config.xml. -->
|
||||
<catboost_dynamic_library_path>/home/catboost/data/libcatboostmodel.so</catboost_dynamic_library_path>
|
||||
<models_config>/home/catboost/models/*_model.xml</models_config>
|
||||
```
|
||||
:::note "Примечание"
|
||||
Вы можете позднее изменить путь к конфигурации модели CatBoost без перезагрузки сервера.
|
||||
:::
|
||||
## 4. Запустите вывод модели из SQL {#run-model-inference}
|
||||
|
||||
Для тестирования модели запустите клиент ClickHouse `$ clickhouse client`.
|
||||
|
||||
Проверьте, что модель работает:
|
||||
|
||||
``` sql
|
||||
:) SELECT
|
||||
modelEvaluate('amazon',
|
||||
RESOURCE,
|
||||
MGR_ID,
|
||||
ROLE_ROLLUP_1,
|
||||
ROLE_ROLLUP_2,
|
||||
ROLE_DEPTNAME,
|
||||
ROLE_TITLE,
|
||||
ROLE_FAMILY_DESC,
|
||||
ROLE_FAMILY,
|
||||
ROLE_CODE) > 0 AS prediction,
|
||||
ACTION AS target
|
||||
FROM amazon_train
|
||||
LIMIT 10
|
||||
```
|
||||
|
||||
:::note "Примечание"
|
||||
Функция [modelEvaluate](../sql-reference/functions/other-functions.md#function-modelevaluate) возвращает кортежи (tuple) с исходными прогнозами по классам для моделей с несколькими классами.
|
||||
:::
|
||||
Спрогнозируйте вероятность:
|
||||
|
||||
``` sql
|
||||
:) SELECT
|
||||
modelEvaluate('amazon',
|
||||
RESOURCE,
|
||||
MGR_ID,
|
||||
ROLE_ROLLUP_1,
|
||||
ROLE_ROLLUP_2,
|
||||
ROLE_DEPTNAME,
|
||||
ROLE_TITLE,
|
||||
ROLE_FAMILY_DESC,
|
||||
ROLE_FAMILY,
|
||||
ROLE_CODE) AS prediction,
|
||||
1. / (1 + exp(-prediction)) AS probability,
|
||||
ACTION AS target
|
||||
FROM amazon_train
|
||||
LIMIT 10
|
||||
```
|
||||
|
||||
:::note "Примечание"
|
||||
Подробнее про функцию [exp()](../sql-reference/functions/math-functions.md).
|
||||
:::
|
||||
Посчитайте логистическую функцию потерь (LogLoss) на всей выборке:
|
||||
|
||||
``` sql
|
||||
:) SELECT -avg(tg * log(prob) + (1 - tg) * log(1 - prob)) AS logloss
|
||||
FROM
|
||||
(
|
||||
SELECT
|
||||
modelEvaluate('amazon',
|
||||
RESOURCE,
|
||||
MGR_ID,
|
||||
ROLE_ROLLUP_1,
|
||||
ROLE_ROLLUP_2,
|
||||
ROLE_DEPTNAME,
|
||||
ROLE_TITLE,
|
||||
ROLE_FAMILY_DESC,
|
||||
ROLE_FAMILY,
|
||||
ROLE_CODE) AS prediction,
|
||||
1. / (1. + exp(-prediction)) AS prob,
|
||||
ACTION AS tg
|
||||
FROM amazon_train
|
||||
)
|
||||
```
|
||||
|
||||
:::note "Примечание"
|
||||
Подробнее про функции [avg()](../sql-reference/aggregate-functions/reference/avg.md#agg_function-avg), [log()](../sql-reference/functions/math-functions.md).
|
||||
:::
|
@ -7,5 +7,3 @@ sidebar_label: "Руководства"
|
||||
# Руководства {#rukovodstva}
|
||||
|
||||
Подробные пошаговые инструкции, которые помогут вам решать различные задачи с помощью ClickHouse.
|
||||
|
||||
- [Применение модели CatBoost в ClickHouse](apply-catboost-model.md)
|
||||
|
@ -29,7 +29,12 @@ SELECT name, status FROM system.dictionaries;
|
||||
|
||||
## RELOAD MODELS {#query_language-system-reload-models}
|
||||
|
||||
Перегружает все модели [CatBoost](../../guides/apply-catboost-model.md#applying-catboost-model-in-clickhouse), если их конфигурация была обновлена, без перезагрузки сервера.
|
||||
:::note
|
||||
Это утверждение и `SYSTEM RELOAD MODEL` просто выгружают модели catboost из clickhouse-library-bridge. Функция `catboostEvaluate()`
|
||||
загружает модель при первом обращении, если она еще не загружена.
|
||||
:::
|
||||
|
||||
Разгрузите все модели CatBoost.
|
||||
|
||||
**Синтаксис**
|
||||
|
||||
@ -39,12 +44,12 @@ SYSTEM RELOAD MODELS
|
||||
|
||||
## RELOAD MODEL {#query_language-system-reload-model}
|
||||
|
||||
Полностью перегружает модель [CatBoost](../../guides/apply-catboost-model.md#applying-catboost-model-in-clickhouse) `model_name`, если ее конфигурация была обновлена, без перезагрузки сервера.
|
||||
Выгружает модель CatBoost по адресу `модель_путь`.
|
||||
|
||||
**Синтаксис**
|
||||
|
||||
```sql
|
||||
SYSTEM RELOAD MODEL <model_name>
|
||||
SYSTEM RELOAD MODEL <model_path>
|
||||
```
|
||||
|
||||
## RELOAD FUNCTIONS {#query_language-system-reload-functions}
|
||||
|
@ -1,244 +0,0 @@
|
||||
---
|
||||
slug: /zh/guides/apply-catboost-model
|
||||
sidebar_position: 41
|
||||
sidebar_label: "\u5E94\u7528CatBoost\u6A21\u578B"
|
||||
---
|
||||
|
||||
# 在ClickHouse中应用Catboost模型 {#applying-catboost-model-in-clickhouse}
|
||||
|
||||
[CatBoost](https://catboost.ai) 是一个由[Yandex](https://yandex.com/company/)开发的开源免费机器学习库。
|
||||
|
||||
|
||||
通过本篇文档,您将学会如何用SQL语句调用已经存放在Clickhouse中的预训练模型来预测数据。
|
||||
|
||||
|
||||
为了在ClickHouse中应用CatBoost模型,需要进行如下步骤:
|
||||
|
||||
1. [创建数据表](#create-table).
|
||||
2. [将数据插入到表中](#insert-data-to-table).
|
||||
3. [将CatBoost集成到ClickHouse中](#integrate-catboost-into-clickhouse) (可跳过)。
|
||||
4. [从SQL运行模型推断](#run-model-inference).
|
||||
|
||||
有关训练CatBoost模型的详细信息,请参阅 [训练和模型应用](https://catboost.ai/docs/features/training.html#training).
|
||||
|
||||
您可以通过[RELOAD MODEL](https://clickhouse.com/docs/en/sql-reference/statements/system/#query_language-system-reload-model)与[RELOAD MODELS](https://clickhouse.com/docs/en/sql-reference/statements/system/#query_language-system-reload-models)语句来重载CatBoost模型。
|
||||
|
||||
## 先决条件 {#prerequisites}
|
||||
|
||||
请先安装 [Docker](https://docs.docker.com/install/)。
|
||||
|
||||
!!! note "注"
|
||||
[Docker](https://www.docker.com) 是一个软件平台,用户可以用Docker来创建独立于已有系统并集成了CatBoost和ClickHouse的容器。
|
||||
|
||||
在应用CatBoost模型之前:
|
||||
|
||||
**1.** 从容器仓库拉取示例docker镜像 (https://hub.docker.com/r/yandex/tutorial-catboost-clickhouse) :
|
||||
|
||||
``` bash
|
||||
$ docker pull yandex/tutorial-catboost-clickhouse
|
||||
```
|
||||
|
||||
此示例Docker镜像包含运行CatBoost和ClickHouse所需的所有内容:代码、运行时、库、环境变量和配置文件。
|
||||
|
||||
**2.** 确保已成功拉取Docker镜像:
|
||||
|
||||
``` bash
|
||||
$ docker image ls
|
||||
REPOSITORY TAG IMAGE ID CREATED SIZE
|
||||
yandex/tutorial-catboost-clickhouse latest 622e4d17945b 22 hours ago 1.37GB
|
||||
```
|
||||
|
||||
**3.** 基于此镜像启动一个Docker容器:
|
||||
|
||||
``` bash
|
||||
$ docker run -it -p 8888:8888 yandex/tutorial-catboost-clickhouse
|
||||
```
|
||||
|
||||
## 1. 创建数据表 {#create-table}
|
||||
|
||||
为训练样本创建ClickHouse表:
|
||||
|
||||
**1.** 在交互模式下启动ClickHouse控制台客户端:
|
||||
|
||||
``` bash
|
||||
$ clickhouse client
|
||||
```
|
||||
|
||||
!!! note "注"
|
||||
ClickHouse服务器已经在Docker容器内运行。
|
||||
|
||||
**2.** 使用以下命令创建表:
|
||||
|
||||
``` sql
|
||||
:) CREATE TABLE amazon_train
|
||||
(
|
||||
date Date MATERIALIZED today(),
|
||||
ACTION UInt8,
|
||||
RESOURCE UInt32,
|
||||
MGR_ID UInt32,
|
||||
ROLE_ROLLUP_1 UInt32,
|
||||
ROLE_ROLLUP_2 UInt32,
|
||||
ROLE_DEPTNAME UInt32,
|
||||
ROLE_TITLE UInt32,
|
||||
ROLE_FAMILY_DESC UInt32,
|
||||
ROLE_FAMILY UInt32,
|
||||
ROLE_CODE UInt32
|
||||
)
|
||||
ENGINE = MergeTree ORDER BY date
|
||||
```
|
||||
|
||||
**3.** 从ClickHouse控制台客户端退出:
|
||||
|
||||
``` sql
|
||||
:) exit
|
||||
```
|
||||
|
||||
## 2. 将数据插入到表中 {#insert-data-to-table}
|
||||
|
||||
插入数据:
|
||||
|
||||
**1.** 运行以下命令:
|
||||
|
||||
``` bash
|
||||
$ clickhouse client --host 127.0.0.1 --query 'INSERT INTO amazon_train FORMAT CSVWithNames' < ~/amazon/train.csv
|
||||
```
|
||||
|
||||
**2.** 在交互模式下启动ClickHouse控制台客户端:
|
||||
|
||||
``` bash
|
||||
$ clickhouse client
|
||||
```
|
||||
|
||||
**3.** 确保数据已上传:
|
||||
|
||||
``` sql
|
||||
:) SELECT count() FROM amazon_train
|
||||
|
||||
SELECT count()
|
||||
FROM amazon_train
|
||||
|
||||
+-count()-+
|
||||
| 65538 |
|
||||
+-------+
|
||||
```
|
||||
|
||||
## 3. 将CatBoost集成到ClickHouse中 {#integrate-catboost-into-clickhouse}
|
||||
|
||||
!!! note "注"
|
||||
**可跳过。** 示例Docker映像已经包含了运行CatBoost和ClickHouse所需的所有内容。
|
||||
|
||||
为了将CatBoost集成进ClickHouse,需要进行如下步骤:
|
||||
|
||||
**1.** 构建评估库。
|
||||
|
||||
评估CatBoost模型的最快方法是编译 `libcatboostmodel.<so|dll|dylib>` 库文件.
|
||||
|
||||
有关如何构建库文件的详细信息,请参阅 [CatBoost文件](https://catboost.ai/docs/concepts/c-plus-plus-api_dynamic-c-pluplus-wrapper.html).
|
||||
|
||||
**2.** 创建一个新目录(位置与名称可随意指定), 如 `data` 并将创建的库文件放入其中。 示例Docker镜像已经包含了库 `data/libcatboostmodel.so`.
|
||||
|
||||
**3.** 创建一个新目录来放配置模型, 如 `models`.
|
||||
|
||||
**4.** 创建一个模型配置文件,如 `models/amazon_model.xml`.
|
||||
|
||||
**5.** 修改模型配置:
|
||||
|
||||
``` xml
|
||||
<models>
|
||||
<model>
|
||||
<!-- Model type. Now catboost only. -->
|
||||
<type>catboost</type>
|
||||
<!-- Model name. -->
|
||||
<name>amazon</name>
|
||||
<!-- Path to trained model. -->
|
||||
<path>/home/catboost/tutorial/catboost_model.bin</path>
|
||||
<!-- Update interval. -->
|
||||
<lifetime>0</lifetime>
|
||||
</model>
|
||||
</models>
|
||||
```
|
||||
|
||||
**6.** 将CatBoost库文件的路径和模型配置添加到ClickHouse配置:
|
||||
|
||||
``` xml
|
||||
<!-- File etc/clickhouse-server/config.d/models_config.xml. -->
|
||||
<catboost_dynamic_library_path>/home/catboost/data/libcatboostmodel.so</catboost_dynamic_library_path>
|
||||
<models_config>/home/catboost/models/*_model.xml</models_config>
|
||||
```
|
||||
|
||||
## 4. 使用SQL调用预测模型 {#run-model-inference}
|
||||
|
||||
为了测试模型是否正常,可以使用ClickHouse客户端 `$ clickhouse client`.
|
||||
|
||||
让我们确保模型能正常工作:
|
||||
|
||||
``` sql
|
||||
:) SELECT
|
||||
modelEvaluate('amazon',
|
||||
RESOURCE,
|
||||
MGR_ID,
|
||||
ROLE_ROLLUP_1,
|
||||
ROLE_ROLLUP_2,
|
||||
ROLE_DEPTNAME,
|
||||
ROLE_TITLE,
|
||||
ROLE_FAMILY_DESC,
|
||||
ROLE_FAMILY,
|
||||
ROLE_CODE) > 0 AS prediction,
|
||||
ACTION AS target
|
||||
FROM amazon_train
|
||||
LIMIT 10
|
||||
```
|
||||
|
||||
!!! note "注"
|
||||
函数 [modelEvaluate](../sql-reference/functions/other-functions.md#function-modelevaluate) 会对多类别模型返回一个元组,其中包含每一类别的原始预测值。
|
||||
|
||||
执行预测:
|
||||
|
||||
``` sql
|
||||
:) SELECT
|
||||
modelEvaluate('amazon',
|
||||
RESOURCE,
|
||||
MGR_ID,
|
||||
ROLE_ROLLUP_1,
|
||||
ROLE_ROLLUP_2,
|
||||
ROLE_DEPTNAME,
|
||||
ROLE_TITLE,
|
||||
ROLE_FAMILY_DESC,
|
||||
ROLE_FAMILY,
|
||||
ROLE_CODE) AS prediction,
|
||||
1. / (1 + exp(-prediction)) AS probability,
|
||||
ACTION AS target
|
||||
FROM amazon_train
|
||||
LIMIT 10
|
||||
```
|
||||
|
||||
!!! note "注"
|
||||
查看函数说明 [exp()](../sql-reference/functions/math-functions.md) 。
|
||||
|
||||
让我们计算样本的LogLoss:
|
||||
|
||||
``` sql
|
||||
:) SELECT -avg(tg * log(prob) + (1 - tg) * log(1 - prob)) AS logloss
|
||||
FROM
|
||||
(
|
||||
SELECT
|
||||
modelEvaluate('amazon',
|
||||
RESOURCE,
|
||||
MGR_ID,
|
||||
ROLE_ROLLUP_1,
|
||||
ROLE_ROLLUP_2,
|
||||
ROLE_DEPTNAME,
|
||||
ROLE_TITLE,
|
||||
ROLE_FAMILY_DESC,
|
||||
ROLE_FAMILY,
|
||||
ROLE_CODE) AS prediction,
|
||||
1. / (1. + exp(-prediction)) AS prob,
|
||||
ACTION AS tg
|
||||
FROM amazon_train
|
||||
)
|
||||
```
|
||||
|
||||
!!! note "注"
|
||||
查看函数说明 [avg()](../sql-reference/aggregate-functions/reference/avg.md#agg_function-avg) 和 [log()](../sql-reference/functions/math-functions.md) 。
|
||||
|
||||
[原始文章](https://clickhouse.com/docs/en/guides/apply_catboost_model/) <!--hide-->
|
@ -9,6 +9,5 @@ sidebar_label: ClickHouse指南
|
||||
列出了如何使用 Clickhouse 解决各种任务的详细说明:
|
||||
|
||||
- [关于简单集群设置的教程](../getting-started/tutorial.md)
|
||||
- [在ClickHouse中应用CatBoost模型](apply-catboost-model.md)
|
||||
|
||||
[原始文章](https://clickhouse.com/docs/en/guides/) <!--hide-->
|
||||
|
@ -54,7 +54,7 @@ else ()
|
||||
endif ()
|
||||
|
||||
if (NOT USE_MUSL)
|
||||
option (ENABLE_CLICKHOUSE_LIBRARY_BRIDGE "HTTP-server working like a proxy to Library dictionary source" ${ENABLE_CLICKHOUSE_ALL})
|
||||
option (ENABLE_CLICKHOUSE_LIBRARY_BRIDGE "HTTP-server working like a proxy to external dynamically loaded libraries" ${ENABLE_CLICKHOUSE_ALL})
|
||||
endif ()
|
||||
|
||||
# https://presentations.clickhouse.com/matemarketing_2020/
|
||||
|
@ -1,6 +1,8 @@
|
||||
include(${ClickHouse_SOURCE_DIR}/cmake/split_debug_symbols.cmake)
|
||||
|
||||
set (CLICKHOUSE_LIBRARY_BRIDGE_SOURCES
|
||||
CatBoostLibraryHandler.cpp
|
||||
CatBoostLibraryHandlerFactory.cpp
|
||||
ExternalDictionaryLibraryAPI.cpp
|
||||
ExternalDictionaryLibraryHandler.cpp
|
||||
ExternalDictionaryLibraryHandlerFactory.cpp
|
||||
|
49
programs/library-bridge/CatBoostLibraryAPI.h
Normal file
49
programs/library-bridge/CatBoostLibraryAPI.h
Normal file
@ -0,0 +1,49 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstddef>
|
||||
|
||||
// Function pointer typedefs and names of libcatboost.so functions used by ClickHouse
|
||||
struct CatBoostLibraryAPI
|
||||
{
|
||||
using ModelCalcerHandle = void;
|
||||
|
||||
using ModelCalcerCreateFunc = ModelCalcerHandle * (*)();
|
||||
static constexpr const char * ModelCalcerCreateName = "ModelCalcerCreate";
|
||||
|
||||
using ModelCalcerDeleteFunc = void (*)(ModelCalcerHandle *);
|
||||
static constexpr const char * ModelCalcerDeleteName = "ModelCalcerDelete";
|
||||
|
||||
using GetErrorStringFunc = const char * (*)();
|
||||
static constexpr const char * GetErrorStringName = "GetErrorString";
|
||||
|
||||
using LoadFullModelFromFileFunc = bool (*)(ModelCalcerHandle *, const char *);
|
||||
static constexpr const char * LoadFullModelFromFileName = "LoadFullModelFromFile";
|
||||
|
||||
using CalcModelPredictionFlatFunc = bool (*)(ModelCalcerHandle *, size_t, const float **, size_t, double *, size_t);
|
||||
static constexpr const char * CalcModelPredictionFlatName = "CalcModelPredictionFlat";
|
||||
|
||||
using CalcModelPredictionFunc = bool (*)(ModelCalcerHandle *, size_t, const float **, size_t, const char ***, size_t, double *, size_t);
|
||||
static constexpr const char * CalcModelPredictionName = "CalcModelPrediction";
|
||||
|
||||
using CalcModelPredictionWithHashedCatFeaturesFunc = bool (*)(ModelCalcerHandle *, size_t, const float **, size_t, const int **, size_t, double *, size_t);
|
||||
static constexpr const char * CalcModelPredictionWithHashedCatFeaturesName = "CalcModelPredictionWithHashedCatFeatures";
|
||||
|
||||
using GetStringCatFeatureHashFunc = int (*)(const char *, size_t);
|
||||
static constexpr const char * GetStringCatFeatureHashName = "GetStringCatFeatureHash";
|
||||
|
||||
using GetIntegerCatFeatureHashFunc = int (*)(uint64_t);
|
||||
static constexpr const char * GetIntegerCatFeatureHashName = "GetIntegerCatFeatureHash";
|
||||
|
||||
using GetFloatFeaturesCountFunc = size_t (*)(ModelCalcerHandle *);
|
||||
static constexpr const char * GetFloatFeaturesCountName = "GetFloatFeaturesCount";
|
||||
|
||||
using GetCatFeaturesCountFunc = size_t (*)(ModelCalcerHandle *);
|
||||
static constexpr const char * GetCatFeaturesCountName = "GetCatFeaturesCount";
|
||||
|
||||
using GetTreeCountFunc = size_t (*)(ModelCalcerHandle *);
|
||||
static constexpr const char * GetTreeCountName = "GetTreeCount";
|
||||
|
||||
using GetDimensionsCountFunc = size_t (*)(ModelCalcerHandle *);
|
||||
static constexpr const char * GetDimensionsCountName = "GetDimensionsCount";
|
||||
};
|
389
programs/library-bridge/CatBoostLibraryHandler.cpp
Normal file
389
programs/library-bridge/CatBoostLibraryHandler.cpp
Normal file
@ -0,0 +1,389 @@
|
||||
#include "CatBoostLibraryHandler.h"
|
||||
|
||||
#include <Columns/ColumnTuple.h>
|
||||
#include <Common/FieldVisitorConvertToNumber.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int BAD_ARGUMENTS;
|
||||
extern const int CANNOT_APPLY_CATBOOST_MODEL;
|
||||
extern const int CANNOT_LOAD_CATBOOST_MODEL;
|
||||
extern const int LOGICAL_ERROR;
|
||||
}
|
||||
|
||||
CatBoostLibraryHandler::APIHolder::APIHolder(SharedLibrary & lib)
|
||||
{
|
||||
ModelCalcerCreate = lib.get<CatBoostLibraryAPI::ModelCalcerCreateFunc>(CatBoostLibraryAPI::ModelCalcerCreateName);
|
||||
ModelCalcerDelete = lib.get<CatBoostLibraryAPI::ModelCalcerDeleteFunc>(CatBoostLibraryAPI::ModelCalcerDeleteName);
|
||||
GetErrorString = lib.get<CatBoostLibraryAPI::GetErrorStringFunc>(CatBoostLibraryAPI::GetErrorStringName);
|
||||
LoadFullModelFromFile = lib.get<CatBoostLibraryAPI::LoadFullModelFromFileFunc>(CatBoostLibraryAPI::LoadFullModelFromFileName);
|
||||
CalcModelPredictionFlat = lib.get<CatBoostLibraryAPI::CalcModelPredictionFlatFunc>(CatBoostLibraryAPI::CalcModelPredictionFlatName);
|
||||
CalcModelPrediction = lib.get<CatBoostLibraryAPI::CalcModelPredictionFunc>(CatBoostLibraryAPI::CalcModelPredictionName);
|
||||
CalcModelPredictionWithHashedCatFeatures = lib.get<CatBoostLibraryAPI::CalcModelPredictionWithHashedCatFeaturesFunc>(CatBoostLibraryAPI::CalcModelPredictionWithHashedCatFeaturesName);
|
||||
GetStringCatFeatureHash = lib.get<CatBoostLibraryAPI::GetStringCatFeatureHashFunc>(CatBoostLibraryAPI::GetStringCatFeatureHashName);
|
||||
GetIntegerCatFeatureHash = lib.get<CatBoostLibraryAPI::GetIntegerCatFeatureHashFunc>(CatBoostLibraryAPI::GetIntegerCatFeatureHashName);
|
||||
GetFloatFeaturesCount = lib.get<CatBoostLibraryAPI::GetFloatFeaturesCountFunc>(CatBoostLibraryAPI::GetFloatFeaturesCountName);
|
||||
GetCatFeaturesCount = lib.get<CatBoostLibraryAPI::GetCatFeaturesCountFunc>(CatBoostLibraryAPI::GetCatFeaturesCountName);
|
||||
GetTreeCount = lib.tryGet<CatBoostLibraryAPI::GetTreeCountFunc>(CatBoostLibraryAPI::GetTreeCountName);
|
||||
GetDimensionsCount = lib.tryGet<CatBoostLibraryAPI::GetDimensionsCountFunc>(CatBoostLibraryAPI::GetDimensionsCountName);
|
||||
}
|
||||
|
||||
CatBoostLibraryHandler::CatBoostLibraryHandler(
|
||||
const std::string & library_path,
|
||||
const std::string & model_path)
|
||||
: loading_start_time(std::chrono::system_clock::now())
|
||||
, library(std::make_shared<SharedLibrary>(library_path))
|
||||
, api(*library)
|
||||
{
|
||||
model_calcer_handle = api.ModelCalcerCreate();
|
||||
|
||||
if (!api.LoadFullModelFromFile(model_calcer_handle, model_path.c_str()))
|
||||
{
|
||||
throw Exception(ErrorCodes::CANNOT_LOAD_CATBOOST_MODEL,
|
||||
"Cannot load CatBoost model: {}", api.GetErrorString());
|
||||
}
|
||||
|
||||
float_features_count = api.GetFloatFeaturesCount(model_calcer_handle);
|
||||
cat_features_count = api.GetCatFeaturesCount(model_calcer_handle);
|
||||
|
||||
tree_count = 1;
|
||||
if (api.GetDimensionsCount)
|
||||
tree_count = api.GetDimensionsCount(model_calcer_handle);
|
||||
|
||||
loading_duration = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now() - loading_start_time);
|
||||
}
|
||||
|
||||
CatBoostLibraryHandler::~CatBoostLibraryHandler()
|
||||
{
|
||||
api.ModelCalcerDelete(model_calcer_handle);
|
||||
}
|
||||
|
||||
std::chrono::system_clock::time_point CatBoostLibraryHandler::getLoadingStartTime() const
|
||||
{
|
||||
return loading_start_time;
|
||||
}
|
||||
|
||||
std::chrono::milliseconds CatBoostLibraryHandler::getLoadingDuration() const
|
||||
{
|
||||
return loading_duration;
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
/// Buffer should be allocated with features_count * column->size() elements.
|
||||
/// Place column elements in positions buffer[0], buffer[features_count], ... , buffer[size * features_count]
|
||||
template <typename T>
|
||||
void placeColumnAsNumber(const IColumn * column, T * buffer, size_t features_count)
|
||||
{
|
||||
size_t size = column->size();
|
||||
FieldVisitorConvertToNumber<T> visitor;
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
/// TODO: Replace with column visitor.
|
||||
Field field;
|
||||
column->get(i, field);
|
||||
*buffer = applyVisitor(visitor, field);
|
||||
buffer += features_count;
|
||||
}
|
||||
}
|
||||
|
||||
/// Buffer should be allocated with features_count * column->size() elements.
|
||||
/// Place string pointers in positions buffer[0], buffer[features_count], ... , buffer[size * features_count]
|
||||
void placeStringColumn(const ColumnString & column, const char ** buffer, size_t features_count)
|
||||
{
|
||||
size_t size = column.size();
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
*buffer = const_cast<char *>(column.getDataAtWithTerminatingZero(i).data);
|
||||
buffer += features_count;
|
||||
}
|
||||
}
|
||||
|
||||
/// Buffer should be allocated with features_count * column->size() elements.
|
||||
/// Place string pointers in positions buffer[0], buffer[features_count], ... , buffer[size * features_count]
|
||||
/// Returns PODArray which holds data (because ColumnFixedString doesn't store terminating zero).
|
||||
PODArray<char> placeFixedStringColumn(const ColumnFixedString & column, const char ** buffer, size_t features_count)
|
||||
{
|
||||
size_t size = column.size();
|
||||
size_t str_size = column.getN();
|
||||
PODArray<char> data(size * (str_size + 1));
|
||||
char * data_ptr = data.data();
|
||||
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
auto ref = column.getDataAt(i);
|
||||
memcpy(data_ptr, ref.data, ref.size);
|
||||
data_ptr[ref.size] = 0;
|
||||
*buffer = data_ptr;
|
||||
data_ptr += ref.size + 1;
|
||||
buffer += features_count;
|
||||
}
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
/// Place columns into buffer, returns column which holds placed data. Buffer should contains column->size() values.
|
||||
template <typename T>
|
||||
ColumnPtr placeNumericColumns(const ColumnRawPtrs & columns, size_t offset, size_t size, const T** buffer)
|
||||
{
|
||||
if (size == 0)
|
||||
return nullptr;
|
||||
|
||||
size_t column_size = columns[offset]->size();
|
||||
auto data_column = ColumnVector<T>::create(size * column_size);
|
||||
T * data = data_column->getData().data();
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
const auto * column = columns[offset + i];
|
||||
if (column->isNumeric())
|
||||
placeColumnAsNumber(column, data + i, size);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < column_size; ++i)
|
||||
{
|
||||
*buffer = data;
|
||||
++buffer;
|
||||
data += size;
|
||||
}
|
||||
|
||||
return data_column;
|
||||
}
|
||||
|
||||
/// Place columns into buffer, returns data which was used for fixed string columns.
|
||||
/// Buffer should contains column->size() values, each value contains size strings.
|
||||
std::vector<PODArray<char>> placeStringColumns(const ColumnRawPtrs & columns, size_t offset, size_t size, const char ** buffer)
|
||||
{
|
||||
if (size == 0)
|
||||
return {};
|
||||
|
||||
std::vector<PODArray<char>> data;
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
const auto * column = columns[offset + i];
|
||||
if (const auto * column_string = typeid_cast<const ColumnString *>(column))
|
||||
placeStringColumn(*column_string, buffer + i, size);
|
||||
else if (const auto * column_fixed_string = typeid_cast<const ColumnFixedString *>(column))
|
||||
data.push_back(placeFixedStringColumn(*column_fixed_string, buffer + i, size));
|
||||
else
|
||||
throw Exception("Cannot place string column.", ErrorCodes::LOGICAL_ERROR);
|
||||
}
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
/// buffer[column_size * cat_features_count] -> char * => cat_features[column_size][cat_features_count] -> char *
|
||||
void fillCatFeaturesBuffer(
|
||||
const char *** cat_features, const char ** buffer,
|
||||
size_t column_size, size_t cat_features_count)
|
||||
{
|
||||
for (size_t i = 0; i < column_size; ++i)
|
||||
{
|
||||
*cat_features = buffer;
|
||||
++cat_features;
|
||||
buffer += cat_features_count;
|
||||
}
|
||||
}
|
||||
|
||||
/// Calc hash for string cat feature at ps positions.
|
||||
template <typename Column>
|
||||
void calcStringHashes(const Column * column, size_t ps, const int ** buffer, const CatBoostLibraryHandler::APIHolder & api)
|
||||
{
|
||||
size_t column_size = column->size();
|
||||
for (size_t j = 0; j < column_size; ++j)
|
||||
{
|
||||
auto ref = column->getDataAt(j);
|
||||
const_cast<int *>(*buffer)[ps] = api.GetStringCatFeatureHash(ref.data, ref.size);
|
||||
++buffer;
|
||||
}
|
||||
}
|
||||
|
||||
/// Calc hash for int cat feature at ps position. Buffer at positions ps should contains unhashed values.
|
||||
void calcIntHashes(size_t column_size, size_t ps, const int ** buffer, const CatBoostLibraryHandler::APIHolder & api)
|
||||
{
|
||||
for (size_t j = 0; j < column_size; ++j)
|
||||
{
|
||||
const_cast<int *>(*buffer)[ps] = api.GetIntegerCatFeatureHash((*buffer)[ps]);
|
||||
++buffer;
|
||||
}
|
||||
}
|
||||
|
||||
/// buffer contains column->size() rows and size columns.
|
||||
/// For int cat features calc hash inplace.
|
||||
/// For string cat features calc hash from column rows.
|
||||
void calcHashes(const ColumnRawPtrs & columns, size_t offset, size_t size, const int ** buffer, const CatBoostLibraryHandler::APIHolder & api)
|
||||
{
|
||||
if (size == 0)
|
||||
return;
|
||||
size_t column_size = columns[offset]->size();
|
||||
|
||||
std::vector<PODArray<char>> data;
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
const auto * column = columns[offset + i];
|
||||
if (const auto * column_string = typeid_cast<const ColumnString *>(column))
|
||||
calcStringHashes(column_string, i, buffer, api);
|
||||
else if (const auto * column_fixed_string = typeid_cast<const ColumnFixedString *>(column))
|
||||
calcStringHashes(column_fixed_string, i, buffer, api);
|
||||
else
|
||||
calcIntHashes(column_size, i, buffer, api);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/// Convert values to row-oriented format and call evaluation function from CatBoost wrapper api.
|
||||
/// * CalcModelPredictionFlat if no cat features
|
||||
/// * CalcModelPrediction if all cat features are strings
|
||||
/// * CalcModelPredictionWithHashedCatFeatures if has int cat features.
|
||||
ColumnFloat64::MutablePtr CatBoostLibraryHandler::evalImpl(
|
||||
const ColumnRawPtrs & columns,
|
||||
bool cat_features_are_strings) const
|
||||
{
|
||||
std::string error_msg = "Error occurred while applying CatBoost model: ";
|
||||
size_t column_size = columns.front()->size();
|
||||
|
||||
auto result = ColumnFloat64::create(column_size * tree_count);
|
||||
auto * result_buf = result->getData().data();
|
||||
|
||||
if (!column_size)
|
||||
return result;
|
||||
|
||||
/// Prepare float features.
|
||||
PODArray<const float *> float_features(column_size);
|
||||
auto * float_features_buf = float_features.data();
|
||||
/// Store all float data into single column. float_features is a list of pointers to it.
|
||||
auto float_features_col = placeNumericColumns<float>(columns, 0, float_features_count, float_features_buf);
|
||||
|
||||
if (cat_features_count == 0)
|
||||
{
|
||||
if (!api.CalcModelPredictionFlat(model_calcer_handle, column_size,
|
||||
float_features_buf, float_features_count,
|
||||
result_buf, column_size * tree_count))
|
||||
{
|
||||
|
||||
throw Exception(error_msg + api.GetErrorString(), ErrorCodes::CANNOT_APPLY_CATBOOST_MODEL);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Prepare cat features.
|
||||
if (cat_features_are_strings)
|
||||
{
|
||||
/// cat_features_holder stores pointers to ColumnString data or fixed_strings_data.
|
||||
PODArray<const char *> cat_features_holder(cat_features_count * column_size);
|
||||
PODArray<const char **> cat_features(column_size);
|
||||
auto * cat_features_buf = cat_features.data();
|
||||
|
||||
fillCatFeaturesBuffer(cat_features_buf, cat_features_holder.data(), column_size, cat_features_count);
|
||||
/// Fixed strings are stored without termination zero, so have to copy data into fixed_strings_data.
|
||||
auto fixed_strings_data = placeStringColumns(columns, float_features_count,
|
||||
cat_features_count, cat_features_holder.data());
|
||||
|
||||
if (!api.CalcModelPrediction(model_calcer_handle, column_size,
|
||||
float_features_buf, float_features_count,
|
||||
cat_features_buf, cat_features_count,
|
||||
result_buf, column_size * tree_count))
|
||||
{
|
||||
throw Exception(error_msg + api.GetErrorString(), ErrorCodes::CANNOT_APPLY_CATBOOST_MODEL);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
PODArray<const int *> cat_features(column_size);
|
||||
auto * cat_features_buf = cat_features.data();
|
||||
auto cat_features_col = placeNumericColumns<int>(columns, float_features_count,
|
||||
cat_features_count, cat_features_buf);
|
||||
calcHashes(columns, float_features_count, cat_features_count, cat_features_buf, api);
|
||||
if (!api.CalcModelPredictionWithHashedCatFeatures(
|
||||
model_calcer_handle, column_size,
|
||||
float_features_buf, float_features_count,
|
||||
cat_features_buf, cat_features_count,
|
||||
result_buf, column_size * tree_count))
|
||||
{
|
||||
throw Exception(error_msg + api.GetErrorString(), ErrorCodes::CANNOT_APPLY_CATBOOST_MODEL);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
size_t CatBoostLibraryHandler::getTreeCount() const
|
||||
{
|
||||
std::lock_guard lock(mutex);
|
||||
return tree_count;
|
||||
}
|
||||
|
||||
ColumnPtr CatBoostLibraryHandler::evaluate(const ColumnRawPtrs & columns) const
|
||||
{
|
||||
std::lock_guard lock(mutex);
|
||||
|
||||
if (columns.empty())
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Got empty columns list for CatBoost model.");
|
||||
|
||||
if (columns.size() != float_features_count + cat_features_count)
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS,
|
||||
"Number of columns is different with number of features: columns size {} float features size {} + cat features size {}",
|
||||
columns.size(),
|
||||
float_features_count,
|
||||
cat_features_count);
|
||||
|
||||
for (size_t i = 0; i < float_features_count; ++i)
|
||||
{
|
||||
if (!columns[i]->isNumeric())
|
||||
{
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Column {} should be numeric to make float feature.", i);
|
||||
}
|
||||
}
|
||||
|
||||
bool cat_features_are_strings = true;
|
||||
for (size_t i = float_features_count; i < float_features_count + cat_features_count; ++i)
|
||||
{
|
||||
const auto * column = columns[i];
|
||||
if (column->isNumeric())
|
||||
{
|
||||
cat_features_are_strings = false;
|
||||
}
|
||||
else if (!(typeid_cast<const ColumnString *>(column)
|
||||
|| typeid_cast<const ColumnFixedString *>(column)))
|
||||
{
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Column {} should be numeric or string.", i);
|
||||
}
|
||||
}
|
||||
|
||||
auto result = evalImpl(columns, cat_features_are_strings);
|
||||
|
||||
if (tree_count == 1)
|
||||
return result;
|
||||
|
||||
size_t column_size = columns.front()->size();
|
||||
auto * result_buf = result->getData().data();
|
||||
|
||||
/// Multiple trees case. Copy data to several columns.
|
||||
MutableColumns mutable_columns(tree_count);
|
||||
std::vector<Float64 *> column_ptrs(tree_count);
|
||||
for (size_t i = 0; i < tree_count; ++i)
|
||||
{
|
||||
auto col = ColumnFloat64::create(column_size);
|
||||
column_ptrs[i] = col->getData().data();
|
||||
mutable_columns[i] = std::move(col);
|
||||
}
|
||||
|
||||
Float64 * data = result_buf;
|
||||
for (size_t row = 0; row < column_size; ++row)
|
||||
{
|
||||
for (size_t i = 0; i < tree_count; ++i)
|
||||
{
|
||||
*column_ptrs[i] = *data;
|
||||
++column_ptrs[i];
|
||||
++data;
|
||||
}
|
||||
}
|
||||
|
||||
return ColumnTuple::create(std::move(mutable_columns));
|
||||
}
|
||||
|
||||
}
|
78
programs/library-bridge/CatBoostLibraryHandler.h
Normal file
78
programs/library-bridge/CatBoostLibraryHandler.h
Normal file
@ -0,0 +1,78 @@
|
||||
#pragma once
|
||||
|
||||
#include "CatBoostLibraryAPI.h"
|
||||
|
||||
#include <Columns/ColumnFixedString.h>
|
||||
#include <Columns/ColumnString.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
#include <Columns/IColumn.h>
|
||||
#include <Common/SharedLibrary.h>
|
||||
#include <base/defines.h>
|
||||
|
||||
#include <chrono>
|
||||
#include <mutex>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
/// Abstracts access to the CatBoost shared library.
|
||||
class CatBoostLibraryHandler
|
||||
{
|
||||
public:
|
||||
/// Holds pointers to CatBoost library functions
|
||||
struct APIHolder
|
||||
{
|
||||
explicit APIHolder(SharedLibrary & lib);
|
||||
|
||||
// NOLINTBEGIN(readability-identifier-naming)
|
||||
CatBoostLibraryAPI::ModelCalcerCreateFunc ModelCalcerCreate;
|
||||
CatBoostLibraryAPI::ModelCalcerDeleteFunc ModelCalcerDelete;
|
||||
CatBoostLibraryAPI::GetErrorStringFunc GetErrorString;
|
||||
CatBoostLibraryAPI::LoadFullModelFromFileFunc LoadFullModelFromFile;
|
||||
CatBoostLibraryAPI::CalcModelPredictionFlatFunc CalcModelPredictionFlat;
|
||||
CatBoostLibraryAPI::CalcModelPredictionFunc CalcModelPrediction;
|
||||
CatBoostLibraryAPI::CalcModelPredictionWithHashedCatFeaturesFunc CalcModelPredictionWithHashedCatFeatures;
|
||||
CatBoostLibraryAPI::GetStringCatFeatureHashFunc GetStringCatFeatureHash;
|
||||
CatBoostLibraryAPI::GetIntegerCatFeatureHashFunc GetIntegerCatFeatureHash;
|
||||
CatBoostLibraryAPI::GetFloatFeaturesCountFunc GetFloatFeaturesCount;
|
||||
CatBoostLibraryAPI::GetCatFeaturesCountFunc GetCatFeaturesCount;
|
||||
CatBoostLibraryAPI::GetTreeCountFunc GetTreeCount;
|
||||
CatBoostLibraryAPI::GetDimensionsCountFunc GetDimensionsCount;
|
||||
// NOLINTEND(readability-identifier-naming)
|
||||
};
|
||||
|
||||
CatBoostLibraryHandler(
|
||||
const String & library_path,
|
||||
const String & model_path);
|
||||
|
||||
~CatBoostLibraryHandler();
|
||||
|
||||
std::chrono::system_clock::time_point getLoadingStartTime() const;
|
||||
std::chrono::milliseconds getLoadingDuration() const;
|
||||
|
||||
size_t getTreeCount() const;
|
||||
|
||||
ColumnPtr evaluate(const ColumnRawPtrs & columns) const;
|
||||
|
||||
private:
|
||||
std::chrono::system_clock::time_point loading_start_time;
|
||||
std::chrono::milliseconds loading_duration;
|
||||
|
||||
const SharedLibraryPtr library;
|
||||
const APIHolder api;
|
||||
|
||||
mutable std::mutex mutex;
|
||||
|
||||
CatBoostLibraryAPI::ModelCalcerHandle * model_calcer_handle TSA_GUARDED_BY(mutex) TSA_PT_GUARDED_BY(mutex);
|
||||
|
||||
size_t float_features_count TSA_GUARDED_BY(mutex);
|
||||
size_t cat_features_count TSA_GUARDED_BY(mutex);
|
||||
size_t tree_count TSA_GUARDED_BY(mutex);
|
||||
|
||||
ColumnFloat64::MutablePtr evalImpl(const ColumnRawPtrs & columns, bool cat_features_are_strings) const TSA_REQUIRES(mutex);
|
||||
};
|
||||
|
||||
using CatBoostLibraryHandlerPtr = std::shared_ptr<CatBoostLibraryHandler>;
|
||||
|
||||
}
|
80
programs/library-bridge/CatBoostLibraryHandlerFactory.cpp
Normal file
80
programs/library-bridge/CatBoostLibraryHandlerFactory.cpp
Normal file
@ -0,0 +1,80 @@
|
||||
#include "CatBoostLibraryHandlerFactory.h"
|
||||
|
||||
#include <Common/logger_useful.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
CatBoostLibraryHandlerFactory & CatBoostLibraryHandlerFactory::instance()
|
||||
{
|
||||
static CatBoostLibraryHandlerFactory instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
CatBoostLibraryHandlerFactory::CatBoostLibraryHandlerFactory()
|
||||
: log(&Poco::Logger::get("CatBoostLibraryHandlerFactory"))
|
||||
{
|
||||
}
|
||||
|
||||
CatBoostLibraryHandlerPtr CatBoostLibraryHandlerFactory::tryGetModel(const String & model_path, const String & library_path, bool create_if_not_found)
|
||||
{
|
||||
std::lock_guard lock(mutex);
|
||||
|
||||
auto handler = library_handlers.find(model_path);
|
||||
bool found = (handler != library_handlers.end());
|
||||
|
||||
if (found)
|
||||
return handler->second;
|
||||
else
|
||||
{
|
||||
if (create_if_not_found)
|
||||
{
|
||||
auto new_handler = std::make_shared<CatBoostLibraryHandler>(library_path, model_path);
|
||||
library_handlers.emplace(model_path, new_handler);
|
||||
LOG_DEBUG(log, "Loaded catboost library handler for model path '{}'", model_path);
|
||||
return new_handler;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void CatBoostLibraryHandlerFactory::removeModel(const String & model_path)
|
||||
{
|
||||
std::lock_guard lock(mutex);
|
||||
|
||||
bool deleted = library_handlers.erase(model_path);
|
||||
if (!deleted)
|
||||
{
|
||||
LOG_DEBUG(log, "Cannot unload catboost library handler for model path '{}'", model_path);
|
||||
return;
|
||||
}
|
||||
LOG_DEBUG(log, "Unloaded catboost library handler for model path '{}'", model_path);
|
||||
}
|
||||
|
||||
void CatBoostLibraryHandlerFactory::removeAllModels()
|
||||
{
|
||||
std::lock_guard lock(mutex);
|
||||
library_handlers.clear();
|
||||
LOG_DEBUG(log, "Unloaded all catboost library handlers");
|
||||
}
|
||||
|
||||
ExternalModelInfos CatBoostLibraryHandlerFactory::getModelInfos()
|
||||
{
|
||||
std::lock_guard lock(mutex);
|
||||
|
||||
ExternalModelInfos result;
|
||||
|
||||
for (const auto & handler : library_handlers)
|
||||
result.push_back({
|
||||
.model_path = handler.first,
|
||||
.model_type = "catboost",
|
||||
.loading_start_time = handler.second->getLoadingStartTime(),
|
||||
.loading_duration = handler.second->getLoadingDuration()
|
||||
});
|
||||
|
||||
return result;
|
||||
|
||||
}
|
||||
|
||||
}
|
37
programs/library-bridge/CatBoostLibraryHandlerFactory.h
Normal file
37
programs/library-bridge/CatBoostLibraryHandlerFactory.h
Normal file
@ -0,0 +1,37 @@
|
||||
#pragma once
|
||||
|
||||
#include "CatBoostLibraryHandler.h"
|
||||
|
||||
#include <base/defines.h>
|
||||
#include <Common/ExternalModelInfo.h>
|
||||
|
||||
#include <chrono>
|
||||
#include <mutex>
|
||||
#include <unordered_map>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
class CatBoostLibraryHandlerFactory final : private boost::noncopyable
|
||||
{
|
||||
public:
|
||||
static CatBoostLibraryHandlerFactory & instance();
|
||||
|
||||
CatBoostLibraryHandlerFactory();
|
||||
|
||||
CatBoostLibraryHandlerPtr tryGetModel(const String & model_path, const String & library_path, bool create_if_not_found);
|
||||
|
||||
void removeModel(const String & model_path);
|
||||
void removeAllModels();
|
||||
|
||||
ExternalModelInfos getModelInfos();
|
||||
|
||||
private:
|
||||
/// map: model path --> catboost library handler
|
||||
std::unordered_map<String, CatBoostLibraryHandlerPtr> library_handlers TSA_GUARDED_BY(mutex);
|
||||
std::mutex mutex;
|
||||
Poco::Logger * log;
|
||||
};
|
||||
|
||||
}
|
@ -50,6 +50,6 @@ private:
|
||||
void * lib_data;
|
||||
};
|
||||
|
||||
using SharedLibraryHandlerPtr = std::shared_ptr<ExternalDictionaryLibraryHandler>;
|
||||
using ExternalDictionaryLibraryHandlerPtr = std::shared_ptr<ExternalDictionaryLibraryHandler>;
|
||||
|
||||
}
|
||||
|
@ -1,37 +1,40 @@
|
||||
#include "ExternalDictionaryLibraryHandlerFactory.h"
|
||||
|
||||
#include <Common/logger_useful.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
SharedLibraryHandlerPtr ExternalDictionaryLibraryHandlerFactory::get(const std::string & dictionary_id)
|
||||
ExternalDictionaryLibraryHandlerPtr ExternalDictionaryLibraryHandlerFactory::get(const String & dictionary_id)
|
||||
{
|
||||
std::lock_guard lock(mutex);
|
||||
auto library_handler = library_handlers.find(dictionary_id);
|
||||
|
||||
if (library_handler != library_handlers.end())
|
||||
return library_handler->second;
|
||||
|
||||
if (auto handler = library_handlers.find(dictionary_id); handler != library_handlers.end())
|
||||
return handler->second;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
||||
void ExternalDictionaryLibraryHandlerFactory::create(
|
||||
const std::string & dictionary_id,
|
||||
const std::string & library_path,
|
||||
const std::vector<std::string> & library_settings,
|
||||
const String & dictionary_id,
|
||||
const String & library_path,
|
||||
const std::vector<String> & library_settings,
|
||||
const Block & sample_block,
|
||||
const std::vector<std::string> & attributes_names)
|
||||
const std::vector<String> & attributes_names)
|
||||
{
|
||||
std::lock_guard lock(mutex);
|
||||
if (!library_handlers.contains(dictionary_id))
|
||||
library_handlers.emplace(std::make_pair(dictionary_id, std::make_shared<ExternalDictionaryLibraryHandler>(library_path, library_settings, sample_block, attributes_names)));
|
||||
else
|
||||
|
||||
if (library_handlers.contains(dictionary_id))
|
||||
{
|
||||
LOG_WARNING(&Poco::Logger::get("ExternalDictionaryLibraryHandlerFactory"), "Library handler with dictionary id {} already exists", dictionary_id);
|
||||
return;
|
||||
}
|
||||
|
||||
library_handlers.emplace(std::make_pair(dictionary_id, std::make_shared<ExternalDictionaryLibraryHandler>(library_path, library_settings, sample_block, attributes_names)));
|
||||
}
|
||||
|
||||
|
||||
bool ExternalDictionaryLibraryHandlerFactory::clone(const std::string & from_dictionary_id, const std::string & to_dictionary_id)
|
||||
bool ExternalDictionaryLibraryHandlerFactory::clone(const String & from_dictionary_id, const String & to_dictionary_id)
|
||||
{
|
||||
std::lock_guard lock(mutex);
|
||||
auto from_library_handler = library_handlers.find(from_dictionary_id);
|
||||
@ -45,7 +48,7 @@ bool ExternalDictionaryLibraryHandlerFactory::clone(const std::string & from_dic
|
||||
}
|
||||
|
||||
|
||||
bool ExternalDictionaryLibraryHandlerFactory::remove(const std::string & dictionary_id)
|
||||
bool ExternalDictionaryLibraryHandlerFactory::remove(const String & dictionary_id)
|
||||
{
|
||||
std::lock_guard lock(mutex);
|
||||
/// extDict_libDelete is called in destructor.
|
||||
|
@ -17,22 +17,22 @@ class ExternalDictionaryLibraryHandlerFactory final : private boost::noncopyable
|
||||
public:
|
||||
static ExternalDictionaryLibraryHandlerFactory & instance();
|
||||
|
||||
SharedLibraryHandlerPtr get(const std::string & dictionary_id);
|
||||
ExternalDictionaryLibraryHandlerPtr get(const String & dictionary_id);
|
||||
|
||||
void create(
|
||||
const std::string & dictionary_id,
|
||||
const std::string & library_path,
|
||||
const std::vector<std::string> & library_settings,
|
||||
const String & dictionary_id,
|
||||
const String & library_path,
|
||||
const std::vector<String> & library_settings,
|
||||
const Block & sample_block,
|
||||
const std::vector<std::string> & attributes_names);
|
||||
const std::vector<String> & attributes_names);
|
||||
|
||||
bool clone(const std::string & from_dictionary_id, const std::string & to_dictionary_id);
|
||||
bool clone(const String & from_dictionary_id, const String & to_dictionary_id);
|
||||
|
||||
bool remove(const std::string & dictionary_id);
|
||||
bool remove(const String & dictionary_id);
|
||||
|
||||
private:
|
||||
/// map: dict_id -> sharedLibraryHandler
|
||||
std::unordered_map<std::string, SharedLibraryHandlerPtr> library_handlers TSA_GUARDED_BY(mutex);
|
||||
std::unordered_map<String, ExternalDictionaryLibraryHandlerPtr> library_handlers TSA_GUARDED_BY(mutex);
|
||||
std::mutex mutex;
|
||||
};
|
||||
|
||||
|
@ -27,12 +27,16 @@ std::unique_ptr<HTTPRequestHandler> LibraryBridgeHandlerFactory::createRequestHa
|
||||
{
|
||||
if (uri.getPath() == "/extdict_ping")
|
||||
return std::make_unique<ExternalDictionaryLibraryBridgeExistsHandler>(keep_alive_timeout, getContext());
|
||||
else if (uri.getPath() == "/catboost_ping")
|
||||
return std::make_unique<CatBoostLibraryBridgeExistsHandler>(keep_alive_timeout, getContext());
|
||||
}
|
||||
|
||||
if (request.getMethod() == Poco::Net::HTTPRequest::HTTP_POST)
|
||||
{
|
||||
if (uri.getPath() == "/extdict_request")
|
||||
return std::make_unique<ExternalDictionaryLibraryBridgeRequestHandler>(keep_alive_timeout, getContext());
|
||||
else if (uri.getPath() == "/catboost_request")
|
||||
return std::make_unique<CatBoostLibraryBridgeRequestHandler>(keep_alive_timeout, getContext());
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
|
@ -1,24 +1,32 @@
|
||||
#include "LibraryBridgeHandlers.h"
|
||||
|
||||
#include "CatBoostLibraryHandler.h"
|
||||
#include "CatBoostLibraryHandlerFactory.h"
|
||||
#include "ExternalDictionaryLibraryHandler.h"
|
||||
#include "ExternalDictionaryLibraryHandlerFactory.h"
|
||||
|
||||
#include <Formats/FormatFactory.h>
|
||||
#include <Server/HTTP/WriteBufferFromHTTPServerResponse.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <IO/ReadBufferFromString.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <Common/BridgeProtocolVersion.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <Poco/Net/HTMLForm.h>
|
||||
#include <Poco/Net/HTTPServerRequest.h>
|
||||
#include <Poco/Net/HTTPServerResponse.h>
|
||||
#include <Poco/Net/HTMLForm.h>
|
||||
#include <Poco/ThreadPool.h>
|
||||
#include <Processors/Formats/IOutputFormat.h>
|
||||
#include <Processors/Formats/IInputFormat.h>
|
||||
#include <QueryPipeline/QueryPipeline.h>
|
||||
#include <Processors/Executors/CompletedPipelineExecutor.h>
|
||||
#include <Processors/Executors/PullingPipelineExecutor.h>
|
||||
#include <Processors/Formats/IInputFormat.h>
|
||||
#include <Processors/Formats/IOutputFormat.h>
|
||||
#include <Processors/Sources/SourceFromSingleChunk.h>
|
||||
#include <QueryPipeline/Pipe.h>
|
||||
#include <QueryPipeline/QueryPipeline.h>
|
||||
#include <Server/HTTP/HTMLForm.h>
|
||||
#include <IO/ReadBufferFromString.h>
|
||||
#include <Server/HTTP/WriteBufferFromHTTPServerResponse.h>
|
||||
#include <Formats/NativeReader.h>
|
||||
#include <Formats/NativeWriter.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypeString.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
@ -31,7 +39,7 @@ namespace ErrorCodes
|
||||
|
||||
namespace
|
||||
{
|
||||
void processError(HTTPServerResponse & response, const std::string & message)
|
||||
void processError(HTTPServerResponse & response, const String & message)
|
||||
{
|
||||
response.setStatusAndReason(HTTPResponse::HTTP_INTERNAL_SERVER_ERROR);
|
||||
|
||||
@ -41,7 +49,7 @@ namespace
|
||||
LOG_WARNING(&Poco::Logger::get("LibraryBridge"), fmt::runtime(message));
|
||||
}
|
||||
|
||||
std::shared_ptr<Block> parseColumns(std::string && column_string)
|
||||
std::shared_ptr<Block> parseColumns(String && column_string)
|
||||
{
|
||||
auto sample_block = std::make_shared<Block>();
|
||||
auto names_and_types = NamesAndTypesList::parse(column_string);
|
||||
@ -59,10 +67,10 @@ namespace
|
||||
return ids;
|
||||
}
|
||||
|
||||
std::vector<std::string> parseNamesFromBinary(const std::string & names_string)
|
||||
std::vector<String> parseNamesFromBinary(const String & names_string)
|
||||
{
|
||||
ReadBufferFromString buf(names_string);
|
||||
std::vector<std::string> names;
|
||||
std::vector<String> names;
|
||||
readVectorBinary(names, buf);
|
||||
return names;
|
||||
}
|
||||
@ -79,13 +87,15 @@ static void writeData(Block data, OutputFormatPtr format)
|
||||
executor.execute();
|
||||
}
|
||||
|
||||
|
||||
ExternalDictionaryLibraryBridgeRequestHandler::ExternalDictionaryLibraryBridgeRequestHandler(size_t keep_alive_timeout_, ContextPtr context_)
|
||||
: WithContext(context_)
|
||||
, log(&Poco::Logger::get("ExternalDictionaryLibraryBridgeRequestHandler"))
|
||||
, keep_alive_timeout(keep_alive_timeout_)
|
||||
, log(&Poco::Logger::get("ExternalDictionaryLibraryBridgeRequestHandler"))
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response)
|
||||
{
|
||||
LOG_TRACE(log, "Request URI: {}", request.getURI());
|
||||
@ -97,7 +107,7 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
|
||||
version = 0; /// assumed version for too old servers which do not send a version
|
||||
else
|
||||
{
|
||||
String version_str = params.get("version");
|
||||
const String & version_str = params.get("version");
|
||||
if (!tryParse(version, version_str))
|
||||
{
|
||||
processError(response, "Unable to parse 'version' string in request URL: '" + version_str + "' Check if the server and library-bridge have the same version.");
|
||||
@ -124,8 +134,8 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
|
||||
return;
|
||||
}
|
||||
|
||||
std::string method = params.get("method");
|
||||
std::string dictionary_id = params.get("dictionary_id");
|
||||
const String & method = params.get("method");
|
||||
const String & dictionary_id = params.get("dictionary_id");
|
||||
|
||||
LOG_TRACE(log, "Library method: '{}', dictionary id: {}", method, dictionary_id);
|
||||
WriteBufferFromHTTPServerResponse out(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD, keep_alive_timeout);
|
||||
@ -141,7 +151,7 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
|
||||
return;
|
||||
}
|
||||
|
||||
std::string from_dictionary_id = params.get("from_dictionary_id");
|
||||
const String & from_dictionary_id = params.get("from_dictionary_id");
|
||||
bool cloned = false;
|
||||
cloned = ExternalDictionaryLibraryHandlerFactory::instance().clone(from_dictionary_id, dictionary_id);
|
||||
|
||||
@ -166,7 +176,7 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
|
||||
return;
|
||||
}
|
||||
|
||||
std::string library_path = params.get("library_path");
|
||||
const String & library_path = params.get("library_path");
|
||||
|
||||
if (!params.has("library_settings"))
|
||||
{
|
||||
@ -174,10 +184,10 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
|
||||
return;
|
||||
}
|
||||
|
||||
const auto & settings_string = params.get("library_settings");
|
||||
const String & settings_string = params.get("library_settings");
|
||||
|
||||
LOG_DEBUG(log, "Parsing library settings from binary string");
|
||||
std::vector<std::string> library_settings = parseNamesFromBinary(settings_string);
|
||||
std::vector<String> library_settings = parseNamesFromBinary(settings_string);
|
||||
|
||||
/// Needed for library dictionary
|
||||
if (!params.has("attributes_names"))
|
||||
@ -186,10 +196,10 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
|
||||
return;
|
||||
}
|
||||
|
||||
const auto & attributes_string = params.get("attributes_names");
|
||||
const String & attributes_string = params.get("attributes_names");
|
||||
|
||||
LOG_DEBUG(log, "Parsing attributes names from binary string");
|
||||
std::vector<std::string> attributes_names = parseNamesFromBinary(attributes_string);
|
||||
std::vector<String> attributes_names = parseNamesFromBinary(attributes_string);
|
||||
|
||||
/// Needed to parse block from binary string format
|
||||
if (!params.has("sample_block"))
|
||||
@ -197,7 +207,7 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
|
||||
processError(response, "No 'sample_block' in request URL");
|
||||
return;
|
||||
}
|
||||
std::string sample_block_string = params.get("sample_block");
|
||||
String sample_block_string = params.get("sample_block");
|
||||
|
||||
std::shared_ptr<Block> sample_block;
|
||||
try
|
||||
@ -297,7 +307,7 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
|
||||
return;
|
||||
}
|
||||
|
||||
std::string requested_block_string = params.get("requested_block_sample");
|
||||
String requested_block_string = params.get("requested_block_sample");
|
||||
|
||||
std::shared_ptr<Block> requested_sample_block;
|
||||
try
|
||||
@ -332,7 +342,8 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
|
||||
}
|
||||
else
|
||||
{
|
||||
LOG_WARNING(log, "Unknown library method: '{}'", method);
|
||||
processError(response, "Unknown library method '" + method + "'");
|
||||
LOG_ERROR(log, "Unknown library method: '{}'", method);
|
||||
}
|
||||
}
|
||||
catch (...)
|
||||
@ -362,6 +373,7 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
ExternalDictionaryLibraryBridgeExistsHandler::ExternalDictionaryLibraryBridgeExistsHandler(size_t keep_alive_timeout_, ContextPtr context_)
|
||||
: WithContext(context_)
|
||||
, keep_alive_timeout(keep_alive_timeout_)
|
||||
@ -369,6 +381,7 @@ ExternalDictionaryLibraryBridgeExistsHandler::ExternalDictionaryLibraryBridgeExi
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
void ExternalDictionaryLibraryBridgeExistsHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response)
|
||||
{
|
||||
try
|
||||
@ -382,7 +395,7 @@ void ExternalDictionaryLibraryBridgeExistsHandler::handleRequest(HTTPServerReque
|
||||
return;
|
||||
}
|
||||
|
||||
std::string dictionary_id = params.get("dictionary_id");
|
||||
const String & dictionary_id = params.get("dictionary_id");
|
||||
|
||||
auto library_handler = ExternalDictionaryLibraryHandlerFactory::instance().get(dictionary_id);
|
||||
|
||||
@ -399,4 +412,230 @@ void ExternalDictionaryLibraryBridgeExistsHandler::handleRequest(HTTPServerReque
|
||||
}
|
||||
|
||||
|
||||
CatBoostLibraryBridgeRequestHandler::CatBoostLibraryBridgeRequestHandler(
|
||||
size_t keep_alive_timeout_, ContextPtr context_)
|
||||
: WithContext(context_)
|
||||
, keep_alive_timeout(keep_alive_timeout_)
|
||||
, log(&Poco::Logger::get("CatBoostLibraryBridgeRequestHandler"))
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
void CatBoostLibraryBridgeRequestHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response)
|
||||
{
|
||||
LOG_TRACE(log, "Request URI: {}", request.getURI());
|
||||
HTMLForm params(getContext()->getSettingsRef(), request);
|
||||
|
||||
size_t version;
|
||||
|
||||
if (!params.has("version"))
|
||||
version = 0; /// assumed version for too old servers which do not send a version
|
||||
else
|
||||
{
|
||||
const String & version_str = params.get("version");
|
||||
if (!tryParse(version, version_str))
|
||||
{
|
||||
processError(response, "Unable to parse 'version' string in request URL: '" + version_str + "' Check if the server and library-bridge have the same version.");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (version != LIBRARY_BRIDGE_PROTOCOL_VERSION)
|
||||
{
|
||||
/// backwards compatibility is considered unnecessary for now, just let the user know that the server and the bridge must be upgraded together
|
||||
processError(response, "Server and library-bridge have different versions: '" + std::to_string(version) + "' vs. '" + std::to_string(LIBRARY_BRIDGE_PROTOCOL_VERSION) + "'");
|
||||
return;
|
||||
}
|
||||
if (!params.has("method"))
|
||||
{
|
||||
processError(response, "No 'method' in request URL");
|
||||
return;
|
||||
}
|
||||
|
||||
const String & method = params.get("method");
|
||||
|
||||
LOG_TRACE(log, "Library method: '{}'", method);
|
||||
WriteBufferFromHTTPServerResponse out(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD, keep_alive_timeout);
|
||||
|
||||
try
|
||||
{
|
||||
if (method == "catboost_list")
|
||||
{
|
||||
ExternalModelInfos model_infos = CatBoostLibraryHandlerFactory::instance().getModelInfos();
|
||||
|
||||
writeIntBinary(static_cast<UInt64>(model_infos.size()), out);
|
||||
|
||||
for (const auto & info : model_infos)
|
||||
{
|
||||
writeStringBinary(info.model_path, out);
|
||||
writeStringBinary(info.model_type, out);
|
||||
|
||||
UInt64 t = std::chrono::system_clock::to_time_t(info.loading_start_time);
|
||||
writeIntBinary(t, out);
|
||||
|
||||
t = info.loading_duration.count();
|
||||
writeIntBinary(t, out);
|
||||
|
||||
}
|
||||
}
|
||||
else if (method == "catboost_removeModel")
|
||||
{
|
||||
auto & read_buf = request.getStream();
|
||||
params.read(read_buf);
|
||||
|
||||
if (!params.has("model_path"))
|
||||
{
|
||||
processError(response, "No 'model_path' in request URL");
|
||||
return;
|
||||
}
|
||||
|
||||
const String & model_path = params.get("model_path");
|
||||
|
||||
CatBoostLibraryHandlerFactory::instance().removeModel(model_path);
|
||||
|
||||
String res = "1";
|
||||
writeStringBinary(res, out);
|
||||
}
|
||||
else if (method == "catboost_removeAllModels")
|
||||
{
|
||||
CatBoostLibraryHandlerFactory::instance().removeAllModels();
|
||||
|
||||
String res = "1";
|
||||
writeStringBinary(res, out);
|
||||
}
|
||||
else if (method == "catboost_GetTreeCount")
|
||||
{
|
||||
auto & read_buf = request.getStream();
|
||||
params.read(read_buf);
|
||||
|
||||
if (!params.has("library_path"))
|
||||
{
|
||||
processError(response, "No 'library_path' in request URL");
|
||||
return;
|
||||
}
|
||||
|
||||
const String & library_path = params.get("library_path");
|
||||
|
||||
if (!params.has("model_path"))
|
||||
{
|
||||
processError(response, "No 'model_path' in request URL");
|
||||
return;
|
||||
}
|
||||
|
||||
const String & model_path = params.get("model_path");
|
||||
|
||||
auto catboost_handler = CatBoostLibraryHandlerFactory::instance().tryGetModel(model_path, library_path, /*create_if_not_found*/ true);
|
||||
size_t tree_count = catboost_handler->getTreeCount();
|
||||
writeIntBinary(tree_count, out);
|
||||
}
|
||||
else if (method == "catboost_libEvaluate")
|
||||
{
|
||||
auto & read_buf = request.getStream();
|
||||
params.read(read_buf);
|
||||
|
||||
if (!params.has("model_path"))
|
||||
{
|
||||
processError(response, "No 'model_path' in request URL");
|
||||
return;
|
||||
}
|
||||
|
||||
const String & model_path = params.get("model_path");
|
||||
|
||||
if (!params.has("data"))
|
||||
{
|
||||
processError(response, "No 'data' in request URL");
|
||||
return;
|
||||
}
|
||||
|
||||
const String & data = params.get("data");
|
||||
|
||||
ReadBufferFromString string_read_buf(data);
|
||||
NativeReader deserializer(string_read_buf, /*server_revision*/ 0);
|
||||
Block block_read = deserializer.read();
|
||||
|
||||
Columns col_ptrs = block_read.getColumns();
|
||||
ColumnRawPtrs col_raw_ptrs;
|
||||
for (const auto & p : col_ptrs)
|
||||
col_raw_ptrs.push_back(&*p);
|
||||
|
||||
auto catboost_handler = CatBoostLibraryHandlerFactory::instance().tryGetModel(model_path, "DummyLibraryPath", /*create_if_not_found*/ false);
|
||||
|
||||
if (!catboost_handler)
|
||||
{
|
||||
processError(response, "CatBoost library is not loaded for model '" + model_path + "'. Please try again.");
|
||||
return;
|
||||
}
|
||||
|
||||
ColumnPtr res_col = catboost_handler->evaluate(col_raw_ptrs);
|
||||
|
||||
DataTypePtr res_col_type = std::make_shared<DataTypeFloat64>();
|
||||
String res_col_name = "res_col";
|
||||
|
||||
ColumnsWithTypeAndName res_cols_with_type_and_name = {{res_col, res_col_type, res_col_name}};
|
||||
|
||||
Block block_write(res_cols_with_type_and_name);
|
||||
NativeWriter serializer{out, /*client_revision*/ 0, block_write};
|
||||
serializer.write(block_write);
|
||||
}
|
||||
else
|
||||
{
|
||||
processError(response, "Unknown library method '" + method + "'");
|
||||
LOG_ERROR(log, "Unknown library method: '{}'", method);
|
||||
}
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
auto message = getCurrentExceptionMessage(true);
|
||||
LOG_ERROR(log, "Failed to process request. Error: {}", message);
|
||||
|
||||
response.setStatusAndReason(Poco::Net::HTTPResponse::HTTP_INTERNAL_SERVER_ERROR, message); // can't call process_error, because of too soon response sending
|
||||
try
|
||||
{
|
||||
writeStringBinary(message, out);
|
||||
out.finalize();
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
tryLogCurrentException(log);
|
||||
}
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
out.finalize();
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
tryLogCurrentException(log);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
CatBoostLibraryBridgeExistsHandler::CatBoostLibraryBridgeExistsHandler(size_t keep_alive_timeout_, ContextPtr context_)
|
||||
: WithContext(context_)
|
||||
, keep_alive_timeout(keep_alive_timeout_)
|
||||
, log(&Poco::Logger::get("CatBoostLibraryBridgeExistsHandler"))
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
void CatBoostLibraryBridgeExistsHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response)
|
||||
{
|
||||
try
|
||||
{
|
||||
LOG_TRACE(log, "Request URI: {}", request.getURI());
|
||||
HTMLForm params(getContext()->getSettingsRef(), request);
|
||||
|
||||
String res = "1";
|
||||
|
||||
setResponseDefaultHeaders(response, keep_alive_timeout);
|
||||
LOG_TRACE(log, "Sending ping response: {}", res);
|
||||
response.sendBuffer(res.data(), res.size());
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
tryLogCurrentException("PingHandler");
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -1,9 +1,8 @@
|
||||
#pragma once
|
||||
|
||||
#include <Common/logger_useful.h>
|
||||
#include <Interpreters/Context.h>
|
||||
#include <Server/HTTP/HTTPRequestHandler.h>
|
||||
#include <Common/logger_useful.h>
|
||||
#include "ExternalDictionaryLibraryHandler.h"
|
||||
|
||||
|
||||
namespace DB
|
||||
@ -26,11 +25,12 @@ public:
|
||||
private:
|
||||
static constexpr inline auto FORMAT = "RowBinary";
|
||||
|
||||
const size_t keep_alive_timeout;
|
||||
Poco::Logger * log;
|
||||
size_t keep_alive_timeout;
|
||||
};
|
||||
|
||||
|
||||
// Handler for checking if the external dictionary library is loaded (used for handshake)
|
||||
class ExternalDictionaryLibraryBridgeExistsHandler : public HTTPRequestHandler, WithContext
|
||||
{
|
||||
public:
|
||||
@ -43,4 +43,47 @@ private:
|
||||
Poco::Logger * log;
|
||||
};
|
||||
|
||||
|
||||
/// Handler for requests to catboost library. The call protocol is as follows:
|
||||
/// (1) Send a "catboost_GetTreeCount" request from the server to the bridge. It contains a library path (e.g /home/user/libcatboost.so) and
|
||||
/// a model path (e.g. /home/user/model.bin). This loads the catboost library handler associated with the model path, then executes
|
||||
/// GetTreeCount() on the library handler and sends the result back to the server.
|
||||
/// (2) Send "catboost_Evaluate" from the server to the bridge. It contains a model path and the features to run the interference on. Step
|
||||
/// (2) is called multiple times (once per chunk) by the server.
|
||||
///
|
||||
/// We would ideally like to have steps (1) and (2) in one atomic handler but can't because the evaluation on the server side is divided
|
||||
/// into two dependent phases: FunctionCatBoostEvaluate::getReturnTypeImpl() and ::executeImpl(). So the model may in principle be unloaded
|
||||
/// from the library-bridge between steps (1) and (2). Step (2) checks if that is the case and fails gracefully. This is okay because that
|
||||
/// situation considered exceptional and rare.
|
||||
///
|
||||
/// An update of a model is performed by unloading it. The first call to "catboost_GetTreeCount" brings it into memory again.
|
||||
///
|
||||
/// Further handlers are provided for unloading a specific model, for unloading all models or for retrieving information about the loaded
|
||||
/// models for display in a system view.
|
||||
class CatBoostLibraryBridgeRequestHandler : public HTTPRequestHandler, WithContext
|
||||
{
|
||||
public:
|
||||
CatBoostLibraryBridgeRequestHandler(size_t keep_alive_timeout_, ContextPtr context_);
|
||||
|
||||
void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) override;
|
||||
|
||||
private:
|
||||
const size_t keep_alive_timeout;
|
||||
Poco::Logger * log;
|
||||
};
|
||||
|
||||
|
||||
// Handler for pinging the library-bridge for catboost access (used for handshake)
|
||||
class CatBoostLibraryBridgeExistsHandler : public HTTPRequestHandler, WithContext
|
||||
{
|
||||
public:
|
||||
CatBoostLibraryBridgeExistsHandler(size_t keep_alive_timeout_, ContextPtr context_);
|
||||
|
||||
void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) override;
|
||||
|
||||
private:
|
||||
const size_t keep_alive_timeout;
|
||||
Poco::Logger * log;
|
||||
};
|
||||
|
||||
}
|
||||
|
@ -51,7 +51,6 @@
|
||||
#include <Interpreters/DNSCacheUpdater.h>
|
||||
#include <Interpreters/DatabaseCatalog.h>
|
||||
#include <Interpreters/ExternalDictionariesLoader.h>
|
||||
#include <Interpreters/ExternalModelsLoader.h>
|
||||
#include <Interpreters/ProcessList.h>
|
||||
#include <Interpreters/loadMetadata.h>
|
||||
#include <Interpreters/UserDefinedSQLObjectsLoader.h>
|
||||
@ -1158,7 +1157,6 @@ int Server::main(const std::vector<std::string> & /*args*/)
|
||||
global_context->setExternalAuthenticatorsConfig(*config);
|
||||
|
||||
global_context->loadOrReloadDictionaries(*config);
|
||||
global_context->loadOrReloadModels(*config);
|
||||
global_context->loadOrReloadUserDefinedExecutableFunctions(*config);
|
||||
|
||||
global_context->setRemoteHostFilter(*config);
|
||||
@ -1739,17 +1737,6 @@ int Server::main(const std::vector<std::string> & /*args*/)
|
||||
throw;
|
||||
}
|
||||
|
||||
/// try to load models immediately, throw on error and die
|
||||
try
|
||||
{
|
||||
global_context->loadOrReloadModels(config());
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
tryLogCurrentException(log, "Caught exception while loading dictionaries.");
|
||||
throw;
|
||||
}
|
||||
|
||||
/// try to load user defined executable functions, throw on error and die
|
||||
try
|
||||
{
|
||||
|
194
src/BridgeHelper/CatBoostLibraryBridgeHelper.cpp
Normal file
194
src/BridgeHelper/CatBoostLibraryBridgeHelper.cpp
Normal file
@ -0,0 +1,194 @@
|
||||
#include "CatBoostLibraryBridgeHelper.h"
|
||||
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
#include <Common/escapeForFileName.h>
|
||||
#include <Core/Block.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <Formats/NativeReader.h>
|
||||
#include <Formats/NativeWriter.h>
|
||||
#include <IO/ReadBufferFromString.h>
|
||||
#include <IO/WriteBufferFromString.h>
|
||||
#include <Poco/Net/HTTPRequest.h>
|
||||
|
||||
#include <random>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int LOGICAL_ERROR;
|
||||
}
|
||||
|
||||
CatBoostLibraryBridgeHelper::CatBoostLibraryBridgeHelper(
|
||||
ContextPtr context_,
|
||||
std::optional<String> model_path_,
|
||||
std::optional<String> library_path_)
|
||||
: LibraryBridgeHelper(context_->getGlobalContext())
|
||||
, model_path(model_path_)
|
||||
, library_path(library_path_)
|
||||
{
|
||||
}
|
||||
|
||||
Poco::URI CatBoostLibraryBridgeHelper::getPingURI() const
|
||||
{
|
||||
auto uri = createBaseURI();
|
||||
uri.setPath(PING_HANDLER);
|
||||
return uri;
|
||||
}
|
||||
|
||||
Poco::URI CatBoostLibraryBridgeHelper::getMainURI() const
|
||||
{
|
||||
auto uri = createBaseURI();
|
||||
uri.setPath(MAIN_HANDLER);
|
||||
return uri;
|
||||
}
|
||||
|
||||
|
||||
Poco::URI CatBoostLibraryBridgeHelper::createRequestURI(const String & method) const
|
||||
{
|
||||
auto uri = getMainURI();
|
||||
uri.addQueryParameter("version", std::to_string(LIBRARY_BRIDGE_PROTOCOL_VERSION));
|
||||
uri.addQueryParameter("method", method);
|
||||
return uri;
|
||||
}
|
||||
|
||||
bool CatBoostLibraryBridgeHelper::bridgeHandShake()
|
||||
{
|
||||
String result;
|
||||
try
|
||||
{
|
||||
ReadWriteBufferFromHTTP buf(getPingURI(), Poco::Net::HTTPRequest::HTTP_GET, {}, http_timeouts, credentials);
|
||||
readString(result, buf);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
tryLogCurrentException(log);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (result != "1")
|
||||
throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected message from library bridge: {}. Check that bridge and server have the same version.", result);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
ExternalModelInfos CatBoostLibraryBridgeHelper::listModels()
|
||||
{
|
||||
startBridgeSync();
|
||||
|
||||
ReadWriteBufferFromHTTP buf(
|
||||
createRequestURI(CATBOOST_LIST_METHOD),
|
||||
Poco::Net::HTTPRequest::HTTP_POST,
|
||||
[](std::ostream &) {},
|
||||
http_timeouts, credentials);
|
||||
|
||||
ExternalModelInfos result;
|
||||
|
||||
UInt64 num_rows;
|
||||
readIntBinary(num_rows, buf);
|
||||
|
||||
for (UInt64 i = 0; i < num_rows; ++i)
|
||||
{
|
||||
ExternalModelInfo info;
|
||||
|
||||
readStringBinary(info.model_path, buf);
|
||||
readStringBinary(info.model_type, buf);
|
||||
|
||||
UInt64 t;
|
||||
readIntBinary(t, buf);
|
||||
info.loading_start_time = std::chrono::system_clock::from_time_t(t);
|
||||
|
||||
readIntBinary(t, buf);
|
||||
info.loading_duration = std::chrono::milliseconds(t);
|
||||
|
||||
result.push_back(info);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
void CatBoostLibraryBridgeHelper::removeModel()
|
||||
{
|
||||
startBridgeSync();
|
||||
|
||||
assert(model_path);
|
||||
|
||||
ReadWriteBufferFromHTTP buf(
|
||||
createRequestURI(CATBOOST_REMOVEMODEL_METHOD),
|
||||
Poco::Net::HTTPRequest::HTTP_POST,
|
||||
[this](std::ostream & os)
|
||||
{
|
||||
os << "model_path=" << escapeForFileName(*model_path);
|
||||
},
|
||||
http_timeouts, credentials);
|
||||
|
||||
String result;
|
||||
readStringBinary(result, buf);
|
||||
assert(result == "1");
|
||||
}
|
||||
|
||||
void CatBoostLibraryBridgeHelper::removeAllModels()
|
||||
{
|
||||
startBridgeSync();
|
||||
|
||||
ReadWriteBufferFromHTTP buf(
|
||||
createRequestURI(CATBOOST_REMOVEALLMODELS_METHOD),
|
||||
Poco::Net::HTTPRequest::HTTP_POST,
|
||||
[](std::ostream &){},
|
||||
http_timeouts, credentials);
|
||||
|
||||
String result;
|
||||
readStringBinary(result, buf);
|
||||
assert(result == "1");
|
||||
}
|
||||
|
||||
size_t CatBoostLibraryBridgeHelper::getTreeCount()
|
||||
{
|
||||
startBridgeSync();
|
||||
|
||||
assert(model_path && library_path);
|
||||
|
||||
ReadWriteBufferFromHTTP buf(
|
||||
createRequestURI(CATBOOST_GETTREECOUNT_METHOD),
|
||||
Poco::Net::HTTPRequest::HTTP_POST,
|
||||
[this](std::ostream & os)
|
||||
{
|
||||
os << "library_path=" << escapeForFileName(*library_path) << "&";
|
||||
os << "model_path=" << escapeForFileName(*model_path);
|
||||
},
|
||||
http_timeouts, credentials);
|
||||
|
||||
size_t result;
|
||||
readIntBinary(result, buf);
|
||||
return result;
|
||||
}
|
||||
|
||||
ColumnPtr CatBoostLibraryBridgeHelper::evaluate(const ColumnsWithTypeAndName & columns)
|
||||
{
|
||||
startBridgeSync();
|
||||
|
||||
WriteBufferFromOwnString string_write_buf;
|
||||
Block block(columns);
|
||||
NativeWriter serializer(string_write_buf, /*client_revision*/ 0, block);
|
||||
serializer.write(block);
|
||||
|
||||
assert(model_path);
|
||||
|
||||
ReadWriteBufferFromHTTP buf(
|
||||
createRequestURI(CATBOOST_LIB_EVALUATE_METHOD),
|
||||
Poco::Net::HTTPRequest::HTTP_POST,
|
||||
[this, serialized = string_write_buf.str()](std::ostream & os)
|
||||
{
|
||||
os << "model_path=" << escapeForFileName(*model_path) << "&";
|
||||
os << "data=" << escapeForFileName(serialized);
|
||||
},
|
||||
http_timeouts, credentials);
|
||||
|
||||
NativeReader deserializer(buf, /*server_revision*/ 0);
|
||||
Block block_read = deserializer.read();
|
||||
|
||||
return block_read.getColumns()[0];
|
||||
}
|
||||
|
||||
}
|
53
src/BridgeHelper/CatBoostLibraryBridgeHelper.h
Normal file
53
src/BridgeHelper/CatBoostLibraryBridgeHelper.h
Normal file
@ -0,0 +1,53 @@
|
||||
#pragma once
|
||||
|
||||
#include <BridgeHelper/LibraryBridgeHelper.h>
|
||||
#include <Common/ExternalModelInfo.h>
|
||||
#include <DataTypes/IDataType.h>
|
||||
#include <IO/ReadWriteBufferFromHTTP.h>
|
||||
#include <Interpreters/Context.h>
|
||||
#include <Poco/URI.h>
|
||||
#include <optional>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
class CatBoostLibraryBridgeHelper : public LibraryBridgeHelper
|
||||
{
|
||||
public:
|
||||
static constexpr inline auto PING_HANDLER = "/catboost_ping";
|
||||
static constexpr inline auto MAIN_HANDLER = "/catboost_request";
|
||||
|
||||
explicit CatBoostLibraryBridgeHelper(
|
||||
ContextPtr context_,
|
||||
std::optional<String> model_path_ = std::nullopt,
|
||||
std::optional<String> library_path_ = std::nullopt);
|
||||
|
||||
ExternalModelInfos listModels();
|
||||
|
||||
void removeModel(); /// requires model_path
|
||||
void removeAllModels();
|
||||
|
||||
size_t getTreeCount(); /// requires model_path and library_path
|
||||
ColumnPtr evaluate(const ColumnsWithTypeAndName & columns); /// requires model_path
|
||||
|
||||
protected:
|
||||
Poco::URI getPingURI() const override;
|
||||
|
||||
Poco::URI getMainURI() const override;
|
||||
|
||||
bool bridgeHandShake() override;
|
||||
|
||||
private:
|
||||
static constexpr inline auto CATBOOST_LIST_METHOD = "catboost_list";
|
||||
static constexpr inline auto CATBOOST_REMOVEMODEL_METHOD = "catboost_removeModel";
|
||||
static constexpr inline auto CATBOOST_REMOVEALLMODELS_METHOD = "catboost_removeAllModels";
|
||||
static constexpr inline auto CATBOOST_GETTREECOUNT_METHOD = "catboost_GetTreeCount";
|
||||
static constexpr inline auto CATBOOST_LIB_EVALUATE_METHOD = "catboost_libEvaluate";
|
||||
|
||||
Poco::URI createRequestURI(const String & method) const;
|
||||
|
||||
const std::optional<String> model_path;
|
||||
const std::optional<String> library_path;
|
||||
};
|
||||
|
||||
}
|
@ -12,8 +12,8 @@
|
||||
namespace DB
|
||||
{
|
||||
|
||||
/// Common base class for XDBC and Library bridge helpers.
|
||||
/// Contains helper methods to check/start bridge sync.
|
||||
/// Base class for server-side bridge helpers, e.g. xdbc-bridge and library-bridge.
|
||||
/// Contains helper methods to check/start bridge sync
|
||||
class IBridgeHelper: protected WithContext
|
||||
{
|
||||
|
||||
|
@ -176,10 +176,10 @@ static void tryLogCurrentExceptionImpl(Poco::Logger * logger, const std::string
|
||||
|
||||
void tryLogCurrentException(const char * log_name, const std::string & start_of_message)
|
||||
{
|
||||
/// Under high memory pressure, any new allocation will definitelly lead
|
||||
/// to MEMORY_LIMIT_EXCEEDED exception.
|
||||
/// Under high memory pressure, new allocations throw a
|
||||
/// MEMORY_LIMIT_EXCEEDED exception.
|
||||
///
|
||||
/// And in this case the exception will not be logged, so let's block the
|
||||
/// In this case the exception will not be logged, so let's block the
|
||||
/// MemoryTracker until the exception will be logged.
|
||||
LockMemoryExceptionInThread lock_memory_tracker(VariableContext::Global);
|
||||
|
||||
@ -189,8 +189,8 @@ void tryLogCurrentException(const char * log_name, const std::string & start_of_
|
||||
|
||||
void tryLogCurrentException(Poco::Logger * logger, const std::string & start_of_message)
|
||||
{
|
||||
/// Under high memory pressure, any new allocation will definitelly lead
|
||||
/// to MEMORY_LIMIT_EXCEEDED exception.
|
||||
/// Under high memory pressure, new allocations throw a
|
||||
/// MEMORY_LIMIT_EXCEEDED exception.
|
||||
///
|
||||
/// And in this case the exception will not be logged, so let's block the
|
||||
/// MemoryTracker until the exception will be logged.
|
||||
|
20
src/Common/ExternalModelInfo.h
Normal file
20
src/Common/ExternalModelInfo.h
Normal file
@ -0,0 +1,20 @@
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <base/types.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
/// Details about external machine learning model, used by clickhouse-server and clickhouse-library-bridge
|
||||
struct ExternalModelInfo
|
||||
{
|
||||
String model_path;
|
||||
String model_type;
|
||||
std::chrono::system_clock::time_point loading_start_time; /// serialized as std::time_t
|
||||
std::chrono::milliseconds loading_duration; /// serialized as UInt64
|
||||
};
|
||||
|
||||
using ExternalModelInfos = std::vector<ExternalModelInfo>;
|
||||
|
||||
}
|
@ -1,18 +1,18 @@
|
||||
#include <Functions/FunctionHelpers.h>
|
||||
#include <Functions/FunctionFactory.h>
|
||||
|
||||
#include <base/range.h>
|
||||
#include <Interpreters/Context.h>
|
||||
#include <Interpreters/ExternalModelsLoader.h>
|
||||
#include <Columns/ColumnString.h>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <DataTypes/DataTypeNullable.h>
|
||||
#include <BridgeHelper/CatBoostLibraryBridgeHelper.h>
|
||||
#include <BridgeHelper/IBridgeHelper.h>
|
||||
#include <Columns/ColumnNullable.h>
|
||||
#include <Columns/ColumnString.h>
|
||||
#include <Columns/ColumnTuple.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
#include <Common/assert_cast.h>
|
||||
#include <DataTypes/DataTypeNullable.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <Functions/IFunction.h>
|
||||
#include <Interpreters/Context.h>
|
||||
#include <Interpreters/Context_fwd.h>
|
||||
|
||||
|
||||
@ -21,66 +21,80 @@ namespace DB
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int FILE_DOESNT_EXIST;
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
extern const int TOO_FEW_ARGUMENTS_FOR_FUNCTION;
|
||||
extern const int ILLEGAL_COLUMN;
|
||||
}
|
||||
|
||||
class ExternalModelsLoader;
|
||||
|
||||
|
||||
/// Evaluate external model.
|
||||
/// First argument - model name, the others - model arguments.
|
||||
/// * for CatBoost model - float features first, then categorical
|
||||
/// Result - Float64.
|
||||
class FunctionModelEvaluate final : public IFunction
|
||||
/// Evaluate CatBoost model.
|
||||
/// - Arguments: float features first, then categorical features.
|
||||
/// - Result: Float64.
|
||||
class FunctionCatBoostEvaluate final : public IFunction, WithContext
|
||||
{
|
||||
private:
|
||||
mutable std::unique_ptr<CatBoostLibraryBridgeHelper> bridge_helper;
|
||||
|
||||
public:
|
||||
static constexpr auto name = "modelEvaluate";
|
||||
static constexpr auto name = "catboostEvaluate";
|
||||
|
||||
static FunctionPtr create(ContextPtr context)
|
||||
{
|
||||
return std::make_shared<FunctionModelEvaluate>(context->getExternalModelsLoader());
|
||||
}
|
||||
|
||||
explicit FunctionModelEvaluate(const ExternalModelsLoader & models_loader_)
|
||||
: models_loader(models_loader_) {}
|
||||
static FunctionPtr create(ContextPtr context_) { return std::make_shared<FunctionCatBoostEvaluate>(context_); }
|
||||
|
||||
explicit FunctionCatBoostEvaluate(ContextPtr context_) : WithContext(context_) {}
|
||||
String getName() const override { return name; }
|
||||
|
||||
bool isVariadic() const override { return true; }
|
||||
|
||||
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }
|
||||
|
||||
bool isDeterministic() const override { return false; }
|
||||
|
||||
bool useDefaultImplementationForNulls() const override { return false; }
|
||||
|
||||
size_t getNumberOfArguments() const override { return 0; }
|
||||
|
||||
void initBridge(const ColumnConst * name_col) const
|
||||
{
|
||||
String library_path = getContext()->getConfigRef().getString("catboost_lib_path");
|
||||
if (!std::filesystem::exists(library_path))
|
||||
throw Exception(ErrorCodes::FILE_DOESNT_EXIST, "Can't load library {}: file doesn't exist", library_path);
|
||||
|
||||
String model_path = name_col->getValue<String>();
|
||||
if (!std::filesystem::exists(model_path))
|
||||
throw Exception(ErrorCodes::FILE_DOESNT_EXIST, "Can't load model {}: file doesn't exist", model_path);
|
||||
|
||||
bridge_helper = std::make_unique<CatBoostLibraryBridgeHelper>(getContext(), model_path, library_path);
|
||||
}
|
||||
|
||||
DataTypePtr getReturnTypeFromLibraryBridge() const
|
||||
{
|
||||
size_t tree_count = bridge_helper->getTreeCount();
|
||||
auto type = std::make_shared<DataTypeFloat64>();
|
||||
if (tree_count == 1)
|
||||
return type;
|
||||
|
||||
DataTypes types(tree_count, type);
|
||||
|
||||
return std::make_shared<DataTypeTuple>(types);
|
||||
}
|
||||
|
||||
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
|
||||
{
|
||||
if (arguments.size() < 2)
|
||||
throw Exception("Function " + getName() + " expects at least 2 arguments",
|
||||
ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION);
|
||||
throw Exception(ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION, "Function {} expects at least 2 arguments", getName());
|
||||
|
||||
if (!isString(arguments[0].type))
|
||||
throw Exception("Illegal type " + arguments[0].type->getName() + " of first argument of function " + getName()
|
||||
+ ", expected a string.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
throw Exception(
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"Illegal type {} of first argument of function {}, expected a string.", arguments[0].type->getName(), getName());
|
||||
|
||||
const auto * name_col = checkAndGetColumnConst<ColumnString>(arguments[0].column.get());
|
||||
if (!name_col)
|
||||
throw Exception("First argument of function " + getName() + " must be a constant string",
|
||||
ErrorCodes::ILLEGAL_COLUMN);
|
||||
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "First argument of function {} must be a constant string", getName());
|
||||
|
||||
initBridge(name_col);
|
||||
|
||||
auto type = getReturnTypeFromLibraryBridge();
|
||||
|
||||
bool has_nullable = false;
|
||||
for (size_t i = 1; i < arguments.size(); ++i)
|
||||
has_nullable = has_nullable || arguments[i].type->isNullable();
|
||||
|
||||
auto model = models_loader.getModel(name_col->getValue<String>());
|
||||
auto type = model->getReturnType();
|
||||
|
||||
if (has_nullable)
|
||||
{
|
||||
if (const auto * tuple = typeid_cast<const DataTypeTuple *>(type.get()))
|
||||
@ -98,31 +112,25 @@ public:
|
||||
return type;
|
||||
}
|
||||
|
||||
|
||||
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t) const override
|
||||
{
|
||||
const auto * name_col = checkAndGetColumnConst<ColumnString>(arguments[0].column.get());
|
||||
if (!name_col)
|
||||
throw Exception("First argument of function " + getName() + " must be a constant string",
|
||||
ErrorCodes::ILLEGAL_COLUMN);
|
||||
|
||||
auto model = models_loader.getModel(name_col->getValue<String>());
|
||||
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "First argument of function {} must be a constant string", getName());
|
||||
|
||||
ColumnRawPtrs column_ptrs;
|
||||
Columns materialized_columns;
|
||||
ColumnPtr null_map;
|
||||
|
||||
column_ptrs.reserve(arguments.size());
|
||||
for (auto arg : collections::range(1, arguments.size()))
|
||||
ColumnsWithTypeAndName feature_arguments(arguments.begin() + 1, arguments.end());
|
||||
for (auto & arg : feature_arguments)
|
||||
{
|
||||
const auto & column = arguments[arg].column;
|
||||
column_ptrs.push_back(column.get());
|
||||
if (auto full_column = column->convertToFullColumnIfConst())
|
||||
if (auto full_column = arg.column->convertToFullColumnIfConst())
|
||||
{
|
||||
materialized_columns.push_back(full_column);
|
||||
column_ptrs.back() = full_column.get();
|
||||
arg.column = full_column;
|
||||
}
|
||||
if (const auto * col_nullable = checkAndGetColumn<ColumnNullable>(*column_ptrs.back()))
|
||||
if (const auto * col_nullable = checkAndGetColumn<ColumnNullable>(&*arg.column))
|
||||
{
|
||||
if (!null_map)
|
||||
null_map = col_nullable->getNullMapColumnPtr();
|
||||
@ -140,11 +148,12 @@ public:
|
||||
null_map = std::move(mut_null_map);
|
||||
}
|
||||
|
||||
column_ptrs.back() = &col_nullable->getNestedColumn();
|
||||
arg.column = col_nullable->getNestedColumn().getPtr();
|
||||
arg.type = static_cast<const DataTypeNullable &>(*arg.type).getNestedType();
|
||||
}
|
||||
}
|
||||
|
||||
auto res = model->evaluate(column_ptrs);
|
||||
auto res = bridge_helper->evaluate(feature_arguments);
|
||||
|
||||
if (null_map)
|
||||
{
|
||||
@ -162,15 +171,12 @@ public:
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
private:
|
||||
const ExternalModelsLoader & models_loader;
|
||||
};
|
||||
|
||||
|
||||
REGISTER_FUNCTION(ExternalModels)
|
||||
REGISTER_FUNCTION(CatBoostEvaluate)
|
||||
{
|
||||
factory.registerFunction<FunctionModelEvaluate>();
|
||||
factory.registerFunction<FunctionCatBoostEvaluate>();
|
||||
}
|
||||
|
||||
}
|
@ -1,525 +0,0 @@
|
||||
#include "CatBoostModel.h"
|
||||
|
||||
#include <Common/FieldVisitorConvertToNumber.h>
|
||||
#include <mutex>
|
||||
#include <Columns/ColumnString.h>
|
||||
#include <Columns/ColumnFixedString.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Columns/ColumnTuple.h>
|
||||
#include <Common/typeid_cast.h>
|
||||
#include <IO/WriteBufferFromString.h>
|
||||
#include <IO/Operators.h>
|
||||
#include <Common/PODArray.h>
|
||||
#include <Common/SharedLibrary.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int LOGICAL_ERROR;
|
||||
extern const int BAD_ARGUMENTS;
|
||||
extern const int CANNOT_LOAD_CATBOOST_MODEL;
|
||||
extern const int CANNOT_APPLY_CATBOOST_MODEL;
|
||||
}
|
||||
|
||||
/// CatBoost wrapper interface functions.
|
||||
class CatBoostWrapperAPI
|
||||
{
|
||||
public:
|
||||
using ModelCalcerHandle = void;
|
||||
|
||||
ModelCalcerHandle * (* ModelCalcerCreate)(); // NOLINT
|
||||
|
||||
void (* ModelCalcerDelete)(ModelCalcerHandle * calcer); // NOLINT
|
||||
|
||||
const char * (* GetErrorString)(); // NOLINT
|
||||
|
||||
bool (* LoadFullModelFromFile)(ModelCalcerHandle * calcer, const char * filename); // NOLINT
|
||||
|
||||
bool (* CalcModelPredictionFlat)(ModelCalcerHandle * calcer, size_t docCount, // NOLINT
|
||||
const float ** floatFeatures, size_t floatFeaturesSize,
|
||||
double * result, size_t resultSize);
|
||||
|
||||
bool (* CalcModelPrediction)(ModelCalcerHandle * calcer, size_t docCount, // NOLINT
|
||||
const float ** floatFeatures, size_t floatFeaturesSize,
|
||||
const char *** catFeatures, size_t catFeaturesSize,
|
||||
double * result, size_t resultSize);
|
||||
|
||||
bool (* CalcModelPredictionWithHashedCatFeatures)(ModelCalcerHandle * calcer, size_t docCount, // NOLINT
|
||||
const float ** floatFeatures, size_t floatFeaturesSize,
|
||||
const int ** catFeatures, size_t catFeaturesSize,
|
||||
double * result, size_t resultSize);
|
||||
|
||||
int (* GetStringCatFeatureHash)(const char * data, size_t size); // NOLINT
|
||||
int (* GetIntegerCatFeatureHash)(uint64_t val); // NOLINT
|
||||
|
||||
size_t (* GetFloatFeaturesCount)(ModelCalcerHandle* calcer); // NOLINT
|
||||
size_t (* GetCatFeaturesCount)(ModelCalcerHandle* calcer); // NOLINT
|
||||
size_t (* GetTreeCount)(ModelCalcerHandle* modelHandle); // NOLINT
|
||||
size_t (* GetDimensionsCount)(ModelCalcerHandle* modelHandle); // NOLINT
|
||||
|
||||
bool (* CheckModelMetadataHasKey)(ModelCalcerHandle* modelHandle, const char* keyPtr, size_t keySize); // NOLINT
|
||||
size_t (*GetModelInfoValueSize)(ModelCalcerHandle* modelHandle, const char* keyPtr, size_t keySize); // NOLINT
|
||||
const char* (*GetModelInfoValue)(ModelCalcerHandle* modelHandle, const char* keyPtr, size_t keySize); // NOLINT
|
||||
};
|
||||
|
||||
|
||||
class CatBoostModelHolder
|
||||
{
|
||||
private:
|
||||
CatBoostWrapperAPI::ModelCalcerHandle * handle;
|
||||
const CatBoostWrapperAPI * api;
|
||||
public:
|
||||
explicit CatBoostModelHolder(const CatBoostWrapperAPI * api_) : api(api_) { handle = api->ModelCalcerCreate(); }
|
||||
~CatBoostModelHolder() { api->ModelCalcerDelete(handle); }
|
||||
|
||||
CatBoostWrapperAPI::ModelCalcerHandle * get() { return handle; }
|
||||
};
|
||||
|
||||
|
||||
/// Holds CatBoost wrapper library and provides wrapper interface.
|
||||
class CatBoostLibHolder
|
||||
{
|
||||
public:
|
||||
explicit CatBoostLibHolder(std::string lib_path_) : lib_path(std::move(lib_path_)), lib(lib_path) { initAPI(); }
|
||||
|
||||
const CatBoostWrapperAPI & getAPI() const { return api; }
|
||||
const std::string & getCurrentPath() const { return lib_path; }
|
||||
|
||||
private:
|
||||
CatBoostWrapperAPI api;
|
||||
std::string lib_path;
|
||||
SharedLibrary lib;
|
||||
|
||||
void initAPI()
|
||||
{
|
||||
load(api.ModelCalcerCreate, "ModelCalcerCreate");
|
||||
load(api.ModelCalcerDelete, "ModelCalcerDelete");
|
||||
load(api.GetErrorString, "GetErrorString");
|
||||
load(api.LoadFullModelFromFile, "LoadFullModelFromFile");
|
||||
load(api.CalcModelPredictionFlat, "CalcModelPredictionFlat");
|
||||
load(api.CalcModelPrediction, "CalcModelPrediction");
|
||||
load(api.CalcModelPredictionWithHashedCatFeatures, "CalcModelPredictionWithHashedCatFeatures");
|
||||
load(api.GetStringCatFeatureHash, "GetStringCatFeatureHash");
|
||||
load(api.GetIntegerCatFeatureHash, "GetIntegerCatFeatureHash");
|
||||
load(api.GetFloatFeaturesCount, "GetFloatFeaturesCount");
|
||||
load(api.GetCatFeaturesCount, "GetCatFeaturesCount");
|
||||
tryLoad(api.CheckModelMetadataHasKey, "CheckModelMetadataHasKey");
|
||||
tryLoad(api.GetModelInfoValueSize, "GetModelInfoValueSize");
|
||||
tryLoad(api.GetModelInfoValue, "GetModelInfoValue");
|
||||
tryLoad(api.GetTreeCount, "GetTreeCount");
|
||||
tryLoad(api.GetDimensionsCount, "GetDimensionsCount");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void load(T& func, const std::string & name) { func = lib.get<T>(name); }
|
||||
|
||||
template <typename T>
|
||||
void tryLoad(T& func, const std::string & name) { func = lib.tryGet<T>(name); }
|
||||
};
|
||||
|
||||
std::shared_ptr<CatBoostLibHolder> getCatBoostWrapperHolder(const std::string & lib_path)
|
||||
{
|
||||
static std::shared_ptr<CatBoostLibHolder> ptr;
|
||||
static std::mutex mutex;
|
||||
|
||||
std::lock_guard lock(mutex);
|
||||
|
||||
if (!ptr || ptr->getCurrentPath() != lib_path)
|
||||
ptr = std::make_shared<CatBoostLibHolder>(lib_path);
|
||||
|
||||
return ptr;
|
||||
}
|
||||
|
||||
class CatBoostModelImpl
|
||||
{
|
||||
public:
|
||||
CatBoostModelImpl(const CatBoostWrapperAPI * api_, const std::string & model_path) : api(api_)
|
||||
{
|
||||
handle = std::make_unique<CatBoostModelHolder>(api);
|
||||
if (!handle)
|
||||
{
|
||||
throw Exception(ErrorCodes::CANNOT_LOAD_CATBOOST_MODEL,
|
||||
"Cannot create CatBoost model: {}",
|
||||
api->GetErrorString());
|
||||
}
|
||||
if (!api->LoadFullModelFromFile(handle->get(), model_path.c_str()))
|
||||
{
|
||||
throw Exception(ErrorCodes::CANNOT_LOAD_CATBOOST_MODEL,
|
||||
"Cannot load CatBoost model: {}",
|
||||
api->GetErrorString());
|
||||
}
|
||||
|
||||
float_features_count = api->GetFloatFeaturesCount(handle->get());
|
||||
cat_features_count = api->GetCatFeaturesCount(handle->get());
|
||||
tree_count = 1;
|
||||
if (api->GetDimensionsCount)
|
||||
tree_count = api->GetDimensionsCount(handle->get());
|
||||
}
|
||||
|
||||
ColumnPtr evaluate(const ColumnRawPtrs & columns) const
|
||||
{
|
||||
if (columns.empty())
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Got empty columns list for CatBoost model.");
|
||||
|
||||
if (columns.size() != float_features_count + cat_features_count)
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS,
|
||||
"Number of columns is different with number of features: columns size {} float features size {} + cat features size {}",
|
||||
columns.size(),
|
||||
float_features_count,
|
||||
cat_features_count);
|
||||
|
||||
for (size_t i = 0; i < float_features_count; ++i)
|
||||
{
|
||||
if (!columns[i]->isNumeric())
|
||||
{
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Column {} should be numeric to make float feature.", i);
|
||||
}
|
||||
}
|
||||
|
||||
bool cat_features_are_strings = true;
|
||||
for (size_t i = float_features_count; i < float_features_count + cat_features_count; ++i)
|
||||
{
|
||||
const auto * column = columns[i];
|
||||
if (column->isNumeric())
|
||||
{
|
||||
cat_features_are_strings = false;
|
||||
}
|
||||
else if (!(typeid_cast<const ColumnString *>(column)
|
||||
|| typeid_cast<const ColumnFixedString *>(column)))
|
||||
{
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Column {} should be numeric or string.", i);
|
||||
}
|
||||
}
|
||||
|
||||
auto result = evalImpl(columns, cat_features_are_strings);
|
||||
|
||||
if (tree_count == 1)
|
||||
return result;
|
||||
|
||||
size_t column_size = columns.front()->size();
|
||||
auto * result_buf = result->getData().data();
|
||||
|
||||
/// Multiple trees case. Copy data to several columns.
|
||||
MutableColumns mutable_columns(tree_count);
|
||||
std::vector<Float64 *> column_ptrs(tree_count);
|
||||
for (size_t i = 0; i < tree_count; ++i)
|
||||
{
|
||||
auto col = ColumnFloat64::create(column_size);
|
||||
column_ptrs[i] = col->getData().data();
|
||||
mutable_columns[i] = std::move(col);
|
||||
}
|
||||
|
||||
Float64 * data = result_buf;
|
||||
for (size_t row = 0; row < column_size; ++row)
|
||||
{
|
||||
for (size_t i = 0; i < tree_count; ++i)
|
||||
{
|
||||
*column_ptrs[i] = *data;
|
||||
++column_ptrs[i];
|
||||
++data;
|
||||
}
|
||||
}
|
||||
|
||||
return ColumnTuple::create(std::move(mutable_columns));
|
||||
}
|
||||
|
||||
size_t getFloatFeaturesCount() const { return float_features_count; }
|
||||
size_t getCatFeaturesCount() const { return cat_features_count; }
|
||||
size_t getTreeCount() const { return tree_count; }
|
||||
|
||||
private:
|
||||
std::unique_ptr<CatBoostModelHolder> handle;
|
||||
const CatBoostWrapperAPI * api;
|
||||
size_t float_features_count;
|
||||
size_t cat_features_count;
|
||||
size_t tree_count;
|
||||
|
||||
/// Buffer should be allocated with features_count * column->size() elements.
|
||||
/// Place column elements in positions buffer[0], buffer[features_count], ... , buffer[size * features_count]
|
||||
template <typename T>
|
||||
void placeColumnAsNumber(const IColumn * column, T * buffer, size_t features_count) const
|
||||
{
|
||||
size_t size = column->size();
|
||||
FieldVisitorConvertToNumber<T> visitor;
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
/// TODO: Replace with column visitor.
|
||||
Field field;
|
||||
column->get(i, field);
|
||||
*buffer = applyVisitor(visitor, field);
|
||||
buffer += features_count;
|
||||
}
|
||||
}
|
||||
|
||||
/// Buffer should be allocated with features_count * column->size() elements.
|
||||
/// Place string pointers in positions buffer[0], buffer[features_count], ... , buffer[size * features_count]
|
||||
static void placeStringColumn(const ColumnString & column, const char ** buffer, size_t features_count)
|
||||
{
|
||||
size_t size = column.size();
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
*buffer = const_cast<char *>(column.getDataAtWithTerminatingZero(i).data);
|
||||
buffer += features_count;
|
||||
}
|
||||
}
|
||||
|
||||
/// Buffer should be allocated with features_count * column->size() elements.
|
||||
/// Place string pointers in positions buffer[0], buffer[features_count], ... , buffer[size * features_count]
|
||||
/// Returns PODArray which holds data (because ColumnFixedString doesn't store terminating zero).
|
||||
static PODArray<char> placeFixedStringColumn(
|
||||
const ColumnFixedString & column, const char ** buffer, size_t features_count)
|
||||
{
|
||||
size_t size = column.size();
|
||||
size_t str_size = column.getN();
|
||||
PODArray<char> data(size * (str_size + 1));
|
||||
char * data_ptr = data.data();
|
||||
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
auto ref = column.getDataAt(i);
|
||||
memcpy(data_ptr, ref.data, ref.size);
|
||||
data_ptr[ref.size] = 0;
|
||||
*buffer = data_ptr;
|
||||
data_ptr += ref.size + 1;
|
||||
buffer += features_count;
|
||||
}
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
/// Place columns into buffer, returns column which holds placed data. Buffer should contains column->size() values.
|
||||
template <typename T>
|
||||
ColumnPtr placeNumericColumns(const ColumnRawPtrs & columns,
|
||||
size_t offset, size_t size, const T** buffer) const
|
||||
{
|
||||
if (size == 0)
|
||||
return nullptr;
|
||||
size_t column_size = columns[offset]->size();
|
||||
auto data_column = ColumnVector<T>::create(size * column_size);
|
||||
T * data = data_column->getData().data();
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
const auto * column = columns[offset + i];
|
||||
if (column->isNumeric())
|
||||
placeColumnAsNumber(column, data + i, size);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < column_size; ++i)
|
||||
{
|
||||
*buffer = data;
|
||||
++buffer;
|
||||
data += size;
|
||||
}
|
||||
|
||||
return data_column;
|
||||
}
|
||||
|
||||
/// Place columns into buffer, returns data which was used for fixed string columns.
|
||||
/// Buffer should contains column->size() values, each value contains size strings.
|
||||
static std::vector<PODArray<char>> placeStringColumns(
|
||||
const ColumnRawPtrs & columns, size_t offset, size_t size, const char ** buffer)
|
||||
{
|
||||
if (size == 0)
|
||||
return {};
|
||||
|
||||
std::vector<PODArray<char>> data;
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
const auto * column = columns[offset + i];
|
||||
if (const auto * column_string = typeid_cast<const ColumnString *>(column))
|
||||
placeStringColumn(*column_string, buffer + i, size);
|
||||
else if (const auto * column_fixed_string = typeid_cast<const ColumnFixedString *>(column))
|
||||
data.push_back(placeFixedStringColumn(*column_fixed_string, buffer + i, size));
|
||||
else
|
||||
throw Exception("Cannot place string column.", ErrorCodes::LOGICAL_ERROR);
|
||||
}
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
/// Calc hash for string cat feature at ps positions.
|
||||
template <typename Column>
|
||||
void calcStringHashes(const Column * column, size_t ps, const int ** buffer) const
|
||||
{
|
||||
size_t column_size = column->size();
|
||||
for (size_t j = 0; j < column_size; ++j)
|
||||
{
|
||||
auto ref = column->getDataAt(j);
|
||||
const_cast<int *>(*buffer)[ps] = api->GetStringCatFeatureHash(ref.data, ref.size);
|
||||
++buffer;
|
||||
}
|
||||
}
|
||||
|
||||
/// Calc hash for int cat feature at ps position. Buffer at positions ps should contains unhashed values.
|
||||
void calcIntHashes(size_t column_size, size_t ps, const int ** buffer) const
|
||||
{
|
||||
for (size_t j = 0; j < column_size; ++j)
|
||||
{
|
||||
const_cast<int *>(*buffer)[ps] = api->GetIntegerCatFeatureHash((*buffer)[ps]);
|
||||
++buffer;
|
||||
}
|
||||
}
|
||||
|
||||
/// buffer contains column->size() rows and size columns.
|
||||
/// For int cat features calc hash inplace.
|
||||
/// For string cat features calc hash from column rows.
|
||||
void calcHashes(const ColumnRawPtrs & columns, size_t offset, size_t size, const int ** buffer) const
|
||||
{
|
||||
if (size == 0)
|
||||
return;
|
||||
size_t column_size = columns[offset]->size();
|
||||
|
||||
std::vector<PODArray<char>> data;
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
const auto * column = columns[offset + i];
|
||||
if (const auto * column_string = typeid_cast<const ColumnString *>(column))
|
||||
calcStringHashes(column_string, i, buffer);
|
||||
else if (const auto * column_fixed_string = typeid_cast<const ColumnFixedString *>(column))
|
||||
calcStringHashes(column_fixed_string, i, buffer);
|
||||
else
|
||||
calcIntHashes(column_size, i, buffer);
|
||||
}
|
||||
}
|
||||
|
||||
/// buffer[column_size * cat_features_count] -> char * => cat_features[column_size][cat_features_count] -> char *
|
||||
void fillCatFeaturesBuffer(const char *** cat_features, const char ** buffer,
|
||||
size_t column_size) const
|
||||
{
|
||||
for (size_t i = 0; i < column_size; ++i)
|
||||
{
|
||||
*cat_features = buffer;
|
||||
++cat_features;
|
||||
buffer += cat_features_count;
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert values to row-oriented format and call evaluation function from CatBoost wrapper api.
|
||||
/// * CalcModelPredictionFlat if no cat features
|
||||
/// * CalcModelPrediction if all cat features are strings
|
||||
/// * CalcModelPredictionWithHashedCatFeatures if has int cat features.
|
||||
ColumnFloat64::MutablePtr evalImpl(
|
||||
const ColumnRawPtrs & columns,
|
||||
bool cat_features_are_strings) const
|
||||
{
|
||||
std::string error_msg = "Error occurred while applying CatBoost model: ";
|
||||
size_t column_size = columns.front()->size();
|
||||
|
||||
auto result = ColumnFloat64::create(column_size * tree_count);
|
||||
auto * result_buf = result->getData().data();
|
||||
|
||||
if (!column_size)
|
||||
return result;
|
||||
|
||||
/// Prepare float features.
|
||||
PODArray<const float *> float_features(column_size);
|
||||
auto * float_features_buf = float_features.data();
|
||||
/// Store all float data into single column. float_features is a list of pointers to it.
|
||||
auto float_features_col = placeNumericColumns<float>(columns, 0, float_features_count, float_features_buf);
|
||||
|
||||
if (cat_features_count == 0)
|
||||
{
|
||||
if (!api->CalcModelPredictionFlat(handle->get(), column_size,
|
||||
float_features_buf, float_features_count,
|
||||
result_buf, column_size * tree_count))
|
||||
{
|
||||
|
||||
throw Exception(error_msg + api->GetErrorString(), ErrorCodes::CANNOT_APPLY_CATBOOST_MODEL);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Prepare cat features.
|
||||
if (cat_features_are_strings)
|
||||
{
|
||||
/// cat_features_holder stores pointers to ColumnString data or fixed_strings_data.
|
||||
PODArray<const char *> cat_features_holder(cat_features_count * column_size);
|
||||
PODArray<const char **> cat_features(column_size);
|
||||
auto * cat_features_buf = cat_features.data();
|
||||
|
||||
fillCatFeaturesBuffer(cat_features_buf, cat_features_holder.data(), column_size);
|
||||
/// Fixed strings are stored without termination zero, so have to copy data into fixed_strings_data.
|
||||
auto fixed_strings_data = placeStringColumns(columns, float_features_count,
|
||||
cat_features_count, cat_features_holder.data());
|
||||
|
||||
if (!api->CalcModelPrediction(handle->get(), column_size,
|
||||
float_features_buf, float_features_count,
|
||||
cat_features_buf, cat_features_count,
|
||||
result_buf, column_size * tree_count))
|
||||
{
|
||||
throw Exception(error_msg + api->GetErrorString(), ErrorCodes::CANNOT_APPLY_CATBOOST_MODEL);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
PODArray<const int *> cat_features(column_size);
|
||||
auto * cat_features_buf = cat_features.data();
|
||||
auto cat_features_col = placeNumericColumns<int>(columns, float_features_count,
|
||||
cat_features_count, cat_features_buf);
|
||||
calcHashes(columns, float_features_count, cat_features_count, cat_features_buf);
|
||||
if (!api->CalcModelPredictionWithHashedCatFeatures(
|
||||
handle->get(), column_size,
|
||||
float_features_buf, float_features_count,
|
||||
cat_features_buf, cat_features_count,
|
||||
result_buf, column_size * tree_count))
|
||||
{
|
||||
throw Exception(error_msg + api->GetErrorString(), ErrorCodes::CANNOT_APPLY_CATBOOST_MODEL);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
CatBoostModel::CatBoostModel(std::string name_, std::string model_path_, std::string lib_path_,
|
||||
const ExternalLoadableLifetime & lifetime_)
|
||||
: name(std::move(name_)), model_path(std::move(model_path_)), lib_path(std::move(lib_path_)), lifetime(lifetime_)
|
||||
{
|
||||
api_provider = getCatBoostWrapperHolder(lib_path);
|
||||
api = &api_provider->getAPI();
|
||||
model = std::make_unique<CatBoostModelImpl>(api, model_path);
|
||||
}
|
||||
|
||||
CatBoostModel::~CatBoostModel() = default;
|
||||
|
||||
size_t CatBoostModel::getFloatFeaturesCount() const
|
||||
{
|
||||
return model->getFloatFeaturesCount();
|
||||
}
|
||||
|
||||
size_t CatBoostModel::getCatFeaturesCount() const
|
||||
{
|
||||
return model->getCatFeaturesCount();
|
||||
}
|
||||
|
||||
size_t CatBoostModel::getTreeCount() const
|
||||
{
|
||||
return model->getTreeCount();
|
||||
}
|
||||
|
||||
DataTypePtr CatBoostModel::getReturnType() const
|
||||
{
|
||||
size_t tree_count = getTreeCount();
|
||||
auto type = std::make_shared<DataTypeFloat64>();
|
||||
if (tree_count == 1)
|
||||
return type;
|
||||
|
||||
DataTypes types(tree_count, type);
|
||||
|
||||
return std::make_shared<DataTypeTuple>(types);
|
||||
}
|
||||
|
||||
ColumnPtr CatBoostModel::evaluate(const ColumnRawPtrs & columns) const
|
||||
{
|
||||
if (!model)
|
||||
throw Exception("CatBoost model was not loaded.", ErrorCodes::LOGICAL_ERROR);
|
||||
|
||||
return model->evaluate(columns);
|
||||
}
|
||||
|
||||
}
|
@ -1,73 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <Interpreters/IExternalLoadable.h>
|
||||
#include <Columns/IColumn.h>
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
class CatBoostLibHolder;
|
||||
class CatBoostWrapperAPI;
|
||||
class CatBoostModelImpl;
|
||||
|
||||
class IDataType;
|
||||
using DataTypePtr = std::shared_ptr<const IDataType>;
|
||||
|
||||
/// General ML model evaluator interface.
|
||||
class IMLModel : public IExternalLoadable
|
||||
{
|
||||
public:
|
||||
IMLModel() = default;
|
||||
virtual ColumnPtr evaluate(const ColumnRawPtrs & columns) const = 0;
|
||||
virtual std::string getTypeName() const = 0;
|
||||
virtual DataTypePtr getReturnType() const = 0;
|
||||
virtual ~IMLModel() override = default;
|
||||
};
|
||||
|
||||
class CatBoostModel : public IMLModel
|
||||
{
|
||||
public:
|
||||
CatBoostModel(std::string name, std::string model_path,
|
||||
std::string lib_path, const ExternalLoadableLifetime & lifetime);
|
||||
|
||||
~CatBoostModel() override;
|
||||
|
||||
ColumnPtr evaluate(const ColumnRawPtrs & columns) const override;
|
||||
std::string getTypeName() const override { return "catboost"; }
|
||||
|
||||
size_t getFloatFeaturesCount() const;
|
||||
size_t getCatFeaturesCount() const;
|
||||
size_t getTreeCount() const;
|
||||
DataTypePtr getReturnType() const override;
|
||||
|
||||
/// IExternalLoadable interface.
|
||||
|
||||
const ExternalLoadableLifetime & getLifetime() const override { return lifetime; }
|
||||
|
||||
std::string getLoadableName() const override { return name; }
|
||||
|
||||
bool supportUpdates() const override { return true; }
|
||||
|
||||
bool isModified() const override { return true; }
|
||||
|
||||
std::shared_ptr<const IExternalLoadable> clone() const override
|
||||
{
|
||||
return std::make_shared<CatBoostModel>(name, model_path, lib_path, lifetime);
|
||||
}
|
||||
|
||||
private:
|
||||
const std::string name;
|
||||
std::string model_path;
|
||||
std::string lib_path;
|
||||
ExternalLoadableLifetime lifetime;
|
||||
std::shared_ptr<CatBoostLibHolder> api_provider;
|
||||
const CatBoostWrapperAPI * api;
|
||||
|
||||
std::unique_ptr<CatBoostModelImpl> model;
|
||||
|
||||
void init();
|
||||
};
|
||||
|
||||
}
|
@ -52,7 +52,6 @@
|
||||
#include <Interpreters/EmbeddedDictionaries.h>
|
||||
#include <Interpreters/ExternalDictionariesLoader.h>
|
||||
#include <Interpreters/ExternalUserDefinedExecutableFunctionsLoader.h>
|
||||
#include <Interpreters/ExternalModelsLoader.h>
|
||||
#include <Interpreters/ExpressionActions.h>
|
||||
#include <Interpreters/ProcessList.h>
|
||||
#include <Interpreters/InterserverCredentials.h>
|
||||
@ -153,7 +152,6 @@ struct ContextSharedPart : boost::noncopyable
|
||||
mutable std::mutex embedded_dictionaries_mutex;
|
||||
mutable std::mutex external_dictionaries_mutex;
|
||||
mutable std::mutex external_user_defined_executable_functions_mutex;
|
||||
mutable std::mutex external_models_mutex;
|
||||
/// Separate mutex for storage policies. During server startup we may
|
||||
/// initialize some important storages (system logs with MergeTree engine)
|
||||
/// under context lock.
|
||||
@ -191,9 +189,7 @@ struct ContextSharedPart : boost::noncopyable
|
||||
mutable std::unique_ptr<EmbeddedDictionaries> embedded_dictionaries; /// Metrica's dictionaries. Have lazy initialization.
|
||||
mutable std::unique_ptr<ExternalDictionariesLoader> external_dictionaries_loader;
|
||||
mutable std::unique_ptr<ExternalUserDefinedExecutableFunctionsLoader> external_user_defined_executable_functions_loader;
|
||||
mutable std::unique_ptr<ExternalModelsLoader> external_models_loader;
|
||||
|
||||
ExternalLoaderXMLConfigRepository * external_models_config_repository = nullptr;
|
||||
scope_guard models_repository_guard;
|
||||
|
||||
ExternalLoaderXMLConfigRepository * external_dictionaries_config_repository = nullptr;
|
||||
@ -359,8 +355,6 @@ struct ContextSharedPart : boost::noncopyable
|
||||
external_dictionaries_loader->enablePeriodicUpdates(false);
|
||||
if (external_user_defined_executable_functions_loader)
|
||||
external_user_defined_executable_functions_loader->enablePeriodicUpdates(false);
|
||||
if (external_models_loader)
|
||||
external_models_loader->enablePeriodicUpdates(false);
|
||||
|
||||
Session::shutdownNamedSessions();
|
||||
|
||||
@ -391,7 +385,6 @@ struct ContextSharedPart : boost::noncopyable
|
||||
std::unique_ptr<EmbeddedDictionaries> delete_embedded_dictionaries;
|
||||
std::unique_ptr<ExternalDictionariesLoader> delete_external_dictionaries_loader;
|
||||
std::unique_ptr<ExternalUserDefinedExecutableFunctionsLoader> delete_external_user_defined_executable_functions_loader;
|
||||
std::unique_ptr<ExternalModelsLoader> delete_external_models_loader;
|
||||
std::unique_ptr<BackgroundSchedulePool> delete_buffer_flush_schedule_pool;
|
||||
std::unique_ptr<BackgroundSchedulePool> delete_schedule_pool;
|
||||
std::unique_ptr<BackgroundSchedulePool> delete_distributed_schedule_pool;
|
||||
@ -430,7 +423,6 @@ struct ContextSharedPart : boost::noncopyable
|
||||
delete_embedded_dictionaries = std::move(embedded_dictionaries);
|
||||
delete_external_dictionaries_loader = std::move(external_dictionaries_loader);
|
||||
delete_external_user_defined_executable_functions_loader = std::move(external_user_defined_executable_functions_loader);
|
||||
delete_external_models_loader = std::move(external_models_loader);
|
||||
delete_buffer_flush_schedule_pool = std::move(buffer_flush_schedule_pool);
|
||||
delete_schedule_pool = std::move(schedule_pool);
|
||||
delete_distributed_schedule_pool = std::move(distributed_schedule_pool);
|
||||
@ -458,7 +450,6 @@ struct ContextSharedPart : boost::noncopyable
|
||||
delete_embedded_dictionaries.reset();
|
||||
delete_external_dictionaries_loader.reset();
|
||||
delete_external_user_defined_executable_functions_loader.reset();
|
||||
delete_external_models_loader.reset();
|
||||
delete_ddl_worker.reset();
|
||||
delete_buffer_flush_schedule_pool.reset();
|
||||
delete_schedule_pool.reset();
|
||||
@ -1476,48 +1467,6 @@ ExternalUserDefinedExecutableFunctionsLoader & Context::getExternalUserDefinedEx
|
||||
return *shared->external_user_defined_executable_functions_loader;
|
||||
}
|
||||
|
||||
const ExternalModelsLoader & Context::getExternalModelsLoader() const
|
||||
{
|
||||
return const_cast<Context *>(this)->getExternalModelsLoader();
|
||||
}
|
||||
|
||||
ExternalModelsLoader & Context::getExternalModelsLoader()
|
||||
{
|
||||
std::lock_guard lock(shared->external_models_mutex);
|
||||
return getExternalModelsLoaderUnlocked();
|
||||
}
|
||||
|
||||
ExternalModelsLoader & Context::getExternalModelsLoaderUnlocked()
|
||||
{
|
||||
if (!shared->external_models_loader)
|
||||
shared->external_models_loader =
|
||||
std::make_unique<ExternalModelsLoader>(getGlobalContext());
|
||||
return *shared->external_models_loader;
|
||||
}
|
||||
|
||||
void Context::loadOrReloadModels(const Poco::Util::AbstractConfiguration & config)
|
||||
{
|
||||
auto patterns_values = getMultipleValuesFromConfig(config, "", "models_config");
|
||||
std::unordered_set<std::string> patterns(patterns_values.begin(), patterns_values.end());
|
||||
|
||||
std::lock_guard lock(shared->external_models_mutex);
|
||||
|
||||
auto & external_models_loader = getExternalModelsLoaderUnlocked();
|
||||
|
||||
if (shared->external_models_config_repository)
|
||||
{
|
||||
shared->external_models_config_repository->updatePatterns(patterns);
|
||||
external_models_loader.reloadConfig(shared->external_models_config_repository->getName());
|
||||
return;
|
||||
}
|
||||
|
||||
auto app_path = getPath();
|
||||
auto config_path = getConfigRef().getString("config-file", "config.xml");
|
||||
auto repository = std::make_unique<ExternalLoaderXMLConfigRepository>(app_path, config_path, patterns);
|
||||
shared->external_models_config_repository = repository.get();
|
||||
shared->models_repository_guard = external_models_loader.addConfigRepository(std::move(repository));
|
||||
}
|
||||
|
||||
EmbeddedDictionaries & Context::getEmbeddedDictionariesImpl(const bool throw_on_error) const
|
||||
{
|
||||
std::lock_guard lock(shared->embedded_dictionaries_mutex);
|
||||
|
@ -53,7 +53,6 @@ class AccessRightsElements;
|
||||
enum class RowPolicyFilterType;
|
||||
class EmbeddedDictionaries;
|
||||
class ExternalDictionariesLoader;
|
||||
class ExternalModelsLoader;
|
||||
class ExternalUserDefinedExecutableFunctionsLoader;
|
||||
class InterserverCredentials;
|
||||
using InterserverCredentialsPtr = std::shared_ptr<const InterserverCredentials>;
|
||||
@ -645,19 +644,15 @@ public:
|
||||
|
||||
const EmbeddedDictionaries & getEmbeddedDictionaries() const;
|
||||
const ExternalDictionariesLoader & getExternalDictionariesLoader() const;
|
||||
const ExternalModelsLoader & getExternalModelsLoader() const;
|
||||
const ExternalUserDefinedExecutableFunctionsLoader & getExternalUserDefinedExecutableFunctionsLoader() const;
|
||||
EmbeddedDictionaries & getEmbeddedDictionaries();
|
||||
ExternalDictionariesLoader & getExternalDictionariesLoader();
|
||||
ExternalDictionariesLoader & getExternalDictionariesLoaderUnlocked();
|
||||
ExternalUserDefinedExecutableFunctionsLoader & getExternalUserDefinedExecutableFunctionsLoader();
|
||||
ExternalUserDefinedExecutableFunctionsLoader & getExternalUserDefinedExecutableFunctionsLoaderUnlocked();
|
||||
ExternalModelsLoader & getExternalModelsLoader();
|
||||
ExternalModelsLoader & getExternalModelsLoaderUnlocked();
|
||||
void tryCreateEmbeddedDictionaries(const Poco::Util::AbstractConfiguration & config) const;
|
||||
void loadOrReloadDictionaries(const Poco::Util::AbstractConfiguration & config);
|
||||
void loadOrReloadUserDefinedExecutableFunctions(const Poco::Util::AbstractConfiguration & config);
|
||||
void loadOrReloadModels(const Poco::Util::AbstractConfiguration & config);
|
||||
|
||||
#if USE_NLP
|
||||
SynonymsExtensions & getSynonymsExtensions() const;
|
||||
|
@ -1,41 +0,0 @@
|
||||
#include <Interpreters/ExternalModelsLoader.h>
|
||||
#include <Interpreters/Context.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int INVALID_CONFIG_PARAMETER;
|
||||
}
|
||||
|
||||
|
||||
ExternalModelsLoader::ExternalModelsLoader(ContextPtr context_)
|
||||
: ExternalLoader("external model", &Poco::Logger::get("ExternalModelsLoader")), WithContext(context_)
|
||||
{
|
||||
setConfigSettings({"model", "name", {}, {}});
|
||||
enablePeriodicUpdates(true);
|
||||
}
|
||||
|
||||
std::shared_ptr<const IExternalLoadable> ExternalModelsLoader::create(
|
||||
const std::string & name, const Poco::Util::AbstractConfiguration & config,
|
||||
const std::string & config_prefix, const std::string & /* repository_name */) const
|
||||
{
|
||||
String type = config.getString(config_prefix + ".type");
|
||||
ExternalLoadableLifetime lifetime(config, config_prefix + ".lifetime");
|
||||
|
||||
/// TODO: add models factory.
|
||||
if (type == "catboost")
|
||||
{
|
||||
return std::make_unique<CatBoostModel>(
|
||||
name, config.getString(config_prefix + ".path"),
|
||||
getContext()->getConfigRef().getString("catboost_dynamic_library_path"),
|
||||
lifetime
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw Exception("Unknown model type: " + type, ErrorCodes::INVALID_CONFIG_PARAMETER);
|
||||
}
|
||||
}
|
||||
}
|
@ -1,40 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <Interpreters/CatBoostModel.h>
|
||||
#include <Interpreters/Context_fwd.h>
|
||||
#include <Interpreters/ExternalLoader.h>
|
||||
#include <Common/logger_useful.h>
|
||||
|
||||
#include <memory>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
/// Manages user-defined models.
|
||||
class ExternalModelsLoader : public ExternalLoader, WithContext
|
||||
{
|
||||
public:
|
||||
using ModelPtr = std::shared_ptr<const IMLModel>;
|
||||
|
||||
/// Models will be loaded immediately and then will be updated in separate thread, each 'reload_period' seconds.
|
||||
explicit ExternalModelsLoader(ContextPtr context_);
|
||||
|
||||
ModelPtr getModel(const std::string & model_name) const
|
||||
{
|
||||
return std::static_pointer_cast<const IMLModel>(load(model_name));
|
||||
}
|
||||
|
||||
void reloadModel(const std::string & model_name) const
|
||||
{
|
||||
loadOrReload(model_name);
|
||||
}
|
||||
|
||||
protected:
|
||||
LoadablePtr create(const std::string & name, const Poco::Util::AbstractConfiguration & config,
|
||||
const std::string & config_prefix, const std::string & repository_name) const override;
|
||||
|
||||
friend class StorageSystemModels;
|
||||
};
|
||||
|
||||
}
|
@ -12,7 +12,6 @@
|
||||
#include <Interpreters/Context.h>
|
||||
#include <Interpreters/DatabaseCatalog.h>
|
||||
#include <Interpreters/ExternalDictionariesLoader.h>
|
||||
#include <Interpreters/ExternalModelsLoader.h>
|
||||
#include <Interpreters/ExternalUserDefinedExecutableFunctionsLoader.h>
|
||||
#include <Interpreters/EmbeddedDictionaries.h>
|
||||
#include <Interpreters/ActionLocksManager.h>
|
||||
@ -36,6 +35,7 @@
|
||||
#include <Interpreters/ProcessorsProfileLog.h>
|
||||
#include <Interpreters/JIT/CompiledExpressionCache.h>
|
||||
#include <Interpreters/TransactionLog.h>
|
||||
#include <BridgeHelper/CatBoostLibraryBridgeHelper.h>
|
||||
#include <Access/ContextAccess.h>
|
||||
#include <Access/Common/AllowedClientHosts.h>
|
||||
#include <Databases/IDatabase.h>
|
||||
@ -387,17 +387,15 @@ BlockIO InterpreterSystemQuery::execute()
|
||||
case Type::RELOAD_MODEL:
|
||||
{
|
||||
getContext()->checkAccess(AccessType::SYSTEM_RELOAD_MODEL);
|
||||
|
||||
auto & external_models_loader = system_context->getExternalModelsLoader();
|
||||
external_models_loader.reloadModel(query.target_model);
|
||||
auto bridge_helper = std::make_unique<CatBoostLibraryBridgeHelper>(getContext(), query.target_model);
|
||||
bridge_helper->removeModel();
|
||||
break;
|
||||
}
|
||||
case Type::RELOAD_MODELS:
|
||||
{
|
||||
getContext()->checkAccess(AccessType::SYSTEM_RELOAD_MODEL);
|
||||
|
||||
auto & external_models_loader = system_context->getExternalModelsLoader();
|
||||
external_models_loader.reloadAllTriedToLoad();
|
||||
auto bridge_helper = std::make_unique<CatBoostLibraryBridgeHelper>(getContext());
|
||||
bridge_helper->removeAllModels();
|
||||
break;
|
||||
}
|
||||
case Type::RELOAD_FUNCTION:
|
||||
|
@ -1,11 +1,11 @@
|
||||
#include <Storages/System/StorageSystemModels.h>
|
||||
#include <Common/ExternalModelInfo.h>
|
||||
#include <DataTypes/DataTypeString.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypeDateTime.h>
|
||||
#include <DataTypes/DataTypeEnum.h>
|
||||
#include <Interpreters/Context.h>
|
||||
#include <Interpreters/ExternalModelsLoader.h>
|
||||
#include <Interpreters/CatBoostModel.h>
|
||||
#include <BridgeHelper/CatBoostLibraryBridgeHelper.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
@ -14,45 +14,24 @@ namespace DB
|
||||
NamesAndTypesList StorageSystemModels::getNamesAndTypes()
|
||||
{
|
||||
return {
|
||||
{ "name", std::make_shared<DataTypeString>() },
|
||||
{ "status", std::make_shared<DataTypeEnum8>(getStatusEnumAllPossibleValues()) },
|
||||
{ "origin", std::make_shared<DataTypeString>() },
|
||||
{ "model_path", std::make_shared<DataTypeString>() },
|
||||
{ "type", std::make_shared<DataTypeString>() },
|
||||
{ "loading_start_time", std::make_shared<DataTypeDateTime>() },
|
||||
{ "loading_duration", std::make_shared<DataTypeFloat32>() },
|
||||
//{ "creation_time", std::make_shared<DataTypeDateTime>() },
|
||||
{ "last_exception", std::make_shared<DataTypeString>() },
|
||||
};
|
||||
}
|
||||
|
||||
void StorageSystemModels::fillData(MutableColumns & res_columns, ContextPtr context, const SelectQueryInfo &) const
|
||||
{
|
||||
const auto & external_models_loader = context->getExternalModelsLoader();
|
||||
auto load_results = external_models_loader.getLoadResults();
|
||||
auto bridge_helper = std::make_unique<CatBoostLibraryBridgeHelper>(context);
|
||||
ExternalModelInfos infos = bridge_helper->listModels();
|
||||
|
||||
for (const auto & load_result : load_results)
|
||||
for (const auto & info : infos)
|
||||
{
|
||||
res_columns[0]->insert(load_result.name);
|
||||
res_columns[1]->insert(static_cast<Int8>(load_result.status));
|
||||
res_columns[2]->insert(load_result.config ? load_result.config->path : "");
|
||||
|
||||
if (load_result.object)
|
||||
{
|
||||
const auto model_ptr = std::static_pointer_cast<const IMLModel>(load_result.object);
|
||||
res_columns[3]->insert(model_ptr->getTypeName());
|
||||
}
|
||||
else
|
||||
{
|
||||
res_columns[3]->insertDefault();
|
||||
}
|
||||
|
||||
res_columns[4]->insert(static_cast<UInt64>(std::chrono::system_clock::to_time_t(load_result.loading_start_time)));
|
||||
res_columns[5]->insert(std::chrono::duration_cast<std::chrono::duration<float>>(load_result.loading_duration).count());
|
||||
|
||||
if (load_result.exception)
|
||||
res_columns[6]->insert(getExceptionMessage(load_result.exception, false));
|
||||
else
|
||||
res_columns[6]->insertDefault();
|
||||
res_columns[0]->insert(info.model_path);
|
||||
res_columns[1]->insert(info.model_type);
|
||||
res_columns[2]->insert(static_cast<UInt64>(std::chrono::system_clock::to_time_t(info.loading_start_time)));
|
||||
res_columns[3]->insert(std::chrono::duration_cast<std::chrono::duration<float>>(info.loading_duration).count());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -763,7 +763,6 @@
|
||||
"MINUTE"
|
||||
"MM"
|
||||
"mod"
|
||||
"modelEvaluate"
|
||||
"MODIFY"
|
||||
"MODIFY COLUMN"
|
||||
"MODIFY ORDER BY"
|
||||
|
@ -469,7 +469,6 @@
|
||||
"subtractSeconds"
|
||||
"alphaTokens"
|
||||
"negate"
|
||||
"modelEvaluate"
|
||||
"file"
|
||||
"roundAge"
|
||||
"MACStringToOUI"
|
||||
|
@ -0,0 +1,3 @@
|
||||
<clickhouse>
|
||||
<catboost_lib_path>/etc/clickhouse-server/model/libcatboostmodel.so</catboost_lib_path>
|
||||
</clickhouse>
|
BIN
tests/integration/test_catboost_evaluate/model/amazon_model.bin
Normal file
BIN
tests/integration/test_catboost_evaluate/model/amazon_model.bin
Normal file
Binary file not shown.
402
tests/integration/test_catboost_evaluate/test.py
Normal file
402
tests/integration/test_catboost_evaluate/test.py
Normal file
@ -0,0 +1,402 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
from helpers.cluster import ClickHouseCluster
|
||||
|
||||
cluster = ClickHouseCluster(__file__)
|
||||
|
||||
instance = cluster.add_instance(
|
||||
"instance", stay_alive=True, main_configs=["config/models_config.xml"]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def ch_cluster():
|
||||
try:
|
||||
cluster.start()
|
||||
|
||||
os.system(
|
||||
"docker cp {local} {cont_id}:{dist}".format(
|
||||
local=os.path.join(SCRIPT_DIR, "model/."),
|
||||
cont_id=instance.docker_id,
|
||||
dist="/etc/clickhouse-server/model",
|
||||
)
|
||||
)
|
||||
instance.restart_clickhouse()
|
||||
|
||||
yield cluster
|
||||
|
||||
finally:
|
||||
cluster.shutdown()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# simple_model.bin has 2 float features and 9 categorical features
|
||||
|
||||
|
||||
def testConstantFeatures(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
result = instance.query("system reload models")
|
||||
|
||||
result = instance.query(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
|
||||
)
|
||||
expected = "-1.930268705869267\n"
|
||||
assert result == expected
|
||||
|
||||
|
||||
def testNonConstantFeatures(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
result = instance.query("system reload models")
|
||||
|
||||
instance.query("DROP TABLE IF EXISTS T;")
|
||||
instance.query(
|
||||
"CREATE TABLE T(ID UInt32, F1 Float32, F2 Float32, F3 UInt32, F4 UInt32, F5 UInt32, F6 UInt32, F7 UInt32, F8 UInt32, F9 Float32, F10 Float32, F11 Float32) ENGINE MergeTree ORDER BY ID;"
|
||||
)
|
||||
instance.query("INSERT INTO T VALUES(0, 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);")
|
||||
|
||||
result = instance.query(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', F1, F2, F3, F4, F5, F6, F7, F8, F9, F10, F11) from T;"
|
||||
)
|
||||
expected = "-1.930268705869267\n"
|
||||
assert result == expected
|
||||
|
||||
instance.query("DROP TABLE IF EXISTS T;")
|
||||
|
||||
|
||||
def testModelPathIsNotAConstString(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
result = instance.query("system reload models")
|
||||
|
||||
err = instance.query_and_get_error(
|
||||
"select catboostEvaluate(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
|
||||
)
|
||||
assert (
|
||||
"Illegal type UInt8 of first argument of function catboostEvaluate, expected a string"
|
||||
in err
|
||||
)
|
||||
|
||||
instance.query("DROP TABLE IF EXISTS T;")
|
||||
instance.query("CREATE TABLE T(ID UInt32, A String) ENGINE MergeTree ORDER BY ID")
|
||||
instance.query("INSERT INTO T VALUES(0, 'test');")
|
||||
err = instance.query_and_get_error(
|
||||
"select catboostEvaluate(A, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11) FROM T;"
|
||||
)
|
||||
assert (
|
||||
"First argument of function catboostEvaluate must be a constant string" in err
|
||||
)
|
||||
instance.query("DROP TABLE IF EXISTS T;")
|
||||
|
||||
|
||||
def testWrongNumberOfFeatureArguments(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
result = instance.query("system reload models")
|
||||
|
||||
err = instance.query_and_get_error(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin');"
|
||||
)
|
||||
assert "Function catboostEvaluate expects at least 2 arguments" in err
|
||||
|
||||
err = instance.query_and_get_error(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1, 2);"
|
||||
)
|
||||
assert (
|
||||
"Number of columns is different with number of features: columns size 2 float features size 2 + cat features size 9"
|
||||
in err
|
||||
)
|
||||
|
||||
|
||||
def testFloatFeatureMustBeNumeric(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
result = instance.query("system reload models")
|
||||
|
||||
err = instance.query_and_get_error(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 'a', 3, 4, 5, 6, 7, 8, 9, 10, 11);"
|
||||
)
|
||||
assert "Column 1 should be numeric to make float feature" in err
|
||||
|
||||
|
||||
def testCategoricalFeatureMustBeNumericOrString(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
result = instance.query("system reload models")
|
||||
|
||||
err = instance.query_and_get_error(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, tuple(8), 9, 10, 11);"
|
||||
)
|
||||
assert "Column 7 should be numeric or string" in err
|
||||
|
||||
|
||||
def testOnLowCardinalityFeatures(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
result = instance.query("system reload models")
|
||||
|
||||
# same but on domain-compressed data
|
||||
result = instance.query(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', toLowCardinality(1.0), toLowCardinality(2.0), toLowCardinality(3), toLowCardinality(4), toLowCardinality(5), toLowCardinality(6), toLowCardinality(7), toLowCardinality(8), toLowCardinality(9), toLowCardinality(10), toLowCardinality(11));"
|
||||
)
|
||||
expected = "-1.930268705869267\n"
|
||||
assert result == expected
|
||||
|
||||
|
||||
def testOnNullableFeatures(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
result = instance.query("system reload models")
|
||||
|
||||
result = instance.query(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', toNullable(1.0), toNullable(2.0), toNullable(3), toNullable(4), toNullable(5), toNullable(6), toNullable(7), toNullable(8), toNullable(9), toNullable(10), toNullable(11));"
|
||||
)
|
||||
expected = "-1.930268705869267\n"
|
||||
assert result == expected
|
||||
|
||||
# Actual NULLs are disallowed
|
||||
err = instance.query_and_get_error(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', toNullable(NULL), toNullable(NULL), toNullable(NULL), toNullable(NULL), toNullable(NULL), toNullable(NULL), toNullable(NULL), toNullable(NULL), toNullable(NULL), toNullable(NULL), toNullable(NULL));"
|
||||
)
|
||||
assert "Column 0 should be numeric to make float feature" in err
|
||||
|
||||
|
||||
def testInvalidLibraryPath(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
result = instance.query("system reload models")
|
||||
|
||||
# temporarily move library elsewhere
|
||||
instance.exec_in_container(
|
||||
[
|
||||
"bash",
|
||||
"-c",
|
||||
"mv /etc/clickhouse-server/model/libcatboostmodel.so /etc/clickhouse-server/model/nonexistant.so",
|
||||
]
|
||||
)
|
||||
|
||||
err = instance.query_and_get_error(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
|
||||
)
|
||||
assert (
|
||||
"Can't load library /etc/clickhouse-server/model/libcatboostmodel.so: file doesn't exist"
|
||||
in err
|
||||
)
|
||||
|
||||
# restore
|
||||
instance.exec_in_container(
|
||||
[
|
||||
"bash",
|
||||
"-c",
|
||||
"mv /etc/clickhouse-server/model/nonexistant.so /etc/clickhouse-server/model/libcatboostmodel.so",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def testInvalidModelPath(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
result = instance.query("system reload models")
|
||||
|
||||
err = instance.query_and_get_error(
|
||||
"select catboostEvaluate('', 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
|
||||
)
|
||||
assert "Can't load model : file doesn't exist" in err
|
||||
|
||||
err = instance.query_and_get_error(
|
||||
"select catboostEvaluate('model_non_existant.bin', 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
|
||||
)
|
||||
assert "Can't load model model_non_existant.bin: file doesn't exist" in err
|
||||
|
||||
|
||||
def testRecoveryAfterCrash(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
result = instance.query("system reload models")
|
||||
|
||||
result = instance.query(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
|
||||
)
|
||||
expected = "-1.930268705869267\n"
|
||||
assert result == expected
|
||||
|
||||
instance.exec_in_container(
|
||||
["bash", "-c", "kill -9 `pidof clickhouse-library-bridge`"], user="root"
|
||||
)
|
||||
|
||||
result = instance.query(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
|
||||
)
|
||||
assert result == expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# amazon_model.bin has 0 float features and 9 categorical features
|
||||
|
||||
|
||||
def testAmazonModelSingleRow(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
result = instance.query("system reload models")
|
||||
|
||||
result = instance.query(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/amazon_model.bin', 1, 2, 3, 4, 5, 6, 7, 8, 9);"
|
||||
)
|
||||
expected = "0.7774665009089274\n"
|
||||
assert result == expected
|
||||
|
||||
|
||||
def testAmazonModelManyRows(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
result = instance.query("system reload models")
|
||||
|
||||
result = instance.query("drop table if exists amazon")
|
||||
|
||||
result = instance.query(
|
||||
"create table amazon ( DATE Date materialized today(), ACTION UInt8, RESOURCE UInt32, MGR_ID UInt32, ROLE_ROLLUP_1 UInt32, ROLE_ROLLUP_2 UInt32, ROLE_DEPTNAME UInt32, ROLE_TITLE UInt32, ROLE_FAMILY_DESC UInt32, ROLE_FAMILY UInt32, ROLE_CODE UInt32) engine = MergeTree order by DATE"
|
||||
)
|
||||
|
||||
result = instance.query(
|
||||
"insert into amazon select number % 256, number, number, number, number, number, number, number, number, number from numbers(7500)"
|
||||
)
|
||||
|
||||
# First compute prediction, then as a very crude way to fingerprint and compare the result: sum and floor
|
||||
# (the focus is to test that the exchange of large result sets between the server and the bridge works)
|
||||
result = instance.query(
|
||||
"SELECT floor(sum(catboostEvaluate('/etc/clickhouse-server/model/amazon_model.bin', RESOURCE, MGR_ID, ROLE_ROLLUP_1, ROLE_ROLLUP_2, ROLE_DEPTNAME, ROLE_TITLE, ROLE_FAMILY_DESC, ROLE_FAMILY, ROLE_CODE))) FROM amazon"
|
||||
)
|
||||
|
||||
expected = "5834\n"
|
||||
assert result == expected
|
||||
|
||||
result = instance.query("drop table if exists amazon")
|
||||
|
||||
|
||||
def testModelUpdate(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
result = instance.query("system reload models")
|
||||
|
||||
query = "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
|
||||
|
||||
result = instance.query(query)
|
||||
expected = "-1.930268705869267\n"
|
||||
assert result == expected
|
||||
|
||||
# simulate an update of the model: temporarily move the amazon model in place of the simple model
|
||||
instance.exec_in_container(
|
||||
[
|
||||
"bash",
|
||||
"-c",
|
||||
"mv /etc/clickhouse-server/model/simple_model.bin /etc/clickhouse-server/model/simple_model.bin.bak",
|
||||
]
|
||||
)
|
||||
instance.exec_in_container(
|
||||
[
|
||||
"bash",
|
||||
"-c",
|
||||
"mv /etc/clickhouse-server/model/amazon_model.bin /etc/clickhouse-server/model/simple_model.bin",
|
||||
]
|
||||
)
|
||||
|
||||
# unload simple model
|
||||
result = instance.query(
|
||||
"system reload model '/etc/clickhouse-server/model/simple_model.bin'"
|
||||
)
|
||||
|
||||
# load the simple-model-camouflaged amazon model
|
||||
result = instance.query(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1, 2, 3, 4, 5, 6, 7, 8, 9);"
|
||||
)
|
||||
expected = "0.7774665009089274\n"
|
||||
assert result == expected
|
||||
|
||||
# restore
|
||||
instance.exec_in_container(
|
||||
[
|
||||
"bash",
|
||||
"-c",
|
||||
"mv /etc/clickhouse-server/model/simple_model.bin /etc/clickhouse-server/model/amazon_model.bin",
|
||||
]
|
||||
)
|
||||
instance.exec_in_container(
|
||||
[
|
||||
"bash",
|
||||
"-c",
|
||||
"mv /etc/clickhouse-server/model/simple_model.bin.bak /etc/clickhouse-server/model/simple_model.bin",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def testSystemModelsAndModelRefresh(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
result = instance.query("system reload models")
|
||||
|
||||
# check model system view
|
||||
result = instance.query("select * from system.models")
|
||||
expected = ""
|
||||
assert result == expected
|
||||
|
||||
# load simple model
|
||||
result = instance.query(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
|
||||
)
|
||||
expected = "-1.930268705869267\n"
|
||||
assert result == expected
|
||||
|
||||
# check model system view with one model loaded
|
||||
result = instance.query("select * from system.models")
|
||||
assert result.count("\n") == 1
|
||||
expected = "/etc/clickhouse-server/model/simple_model.bin"
|
||||
assert expected in result
|
||||
|
||||
# load amazon model
|
||||
result = instance.query(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/amazon_model.bin', 1, 2, 3, 4, 5, 6, 7, 8, 9);"
|
||||
)
|
||||
expected = "0.7774665009089274\n"
|
||||
assert result == expected
|
||||
|
||||
# check model system view with one model loaded
|
||||
result = instance.query("select * from system.models")
|
||||
assert result.count("\n") == 2
|
||||
expected = "/etc/clickhouse-server/model/simple_model.bin"
|
||||
assert expected in result
|
||||
expected = "/etc/clickhouse-server/model/amazon_model.bin"
|
||||
assert expected in result
|
||||
|
||||
# unload simple model
|
||||
result = instance.query(
|
||||
"system reload model '/etc/clickhouse-server/model/simple_model.bin'"
|
||||
)
|
||||
|
||||
# check model system view, it should not display the removed model
|
||||
result = instance.query("select * from system.models")
|
||||
assert result.count("\n") == 1
|
||||
expected = "/etc/clickhouse-server/model/amazon_model.bin"
|
||||
assert expected in result
|
@ -1,3 +0,0 @@
|
||||
<clickhouse>
|
||||
<catboost_dynamic_library_path>/etc/clickhouse-server/model/libcatboostmodel.so</catboost_dynamic_library_path>
|
||||
</clickhouse>
|
@ -1,2 +0,0 @@
|
||||
<clickhouse>
|
||||
</clickhouse>
|
@ -1,8 +0,0 @@
|
||||
<models>
|
||||
<model>
|
||||
<type>catboost</type>
|
||||
<name>model1</name>
|
||||
<path>/etc/clickhouse-server/model/model.bin</path>
|
||||
<lifetime>0</lifetime>
|
||||
</model>
|
||||
</models>
|
@ -1,8 +0,0 @@
|
||||
<models>
|
||||
<model>
|
||||
<type>catboost</type>
|
||||
<name>model2</name>
|
||||
<path>/etc/clickhouse-server/model/model.bin</path>
|
||||
<lifetime>0</lifetime>
|
||||
</model>
|
||||
</models>
|
@ -1,77 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
from helpers.cluster import ClickHouseCluster
|
||||
|
||||
cluster = ClickHouseCluster(__file__)
|
||||
node = cluster.add_instance(
|
||||
"node",
|
||||
stay_alive=True,
|
||||
main_configs=["config/models_config.xml", "config/catboost_lib.xml"],
|
||||
)
|
||||
|
||||
|
||||
def copy_file_to_container(local_path, dist_path, container_id):
|
||||
os.system(
|
||||
"docker cp {local} {cont_id}:{dist}".format(
|
||||
local=local_path, cont_id=container_id, dist=dist_path
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
config = """<clickhouse>
|
||||
<models_config>/etc/clickhouse-server/model/{model_config}</models_config>
|
||||
</clickhouse>"""
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def started_cluster():
|
||||
try:
|
||||
cluster.start()
|
||||
|
||||
copy_file_to_container(
|
||||
os.path.join(SCRIPT_DIR, "model/."),
|
||||
"/etc/clickhouse-server/model",
|
||||
node.docker_id,
|
||||
)
|
||||
node.restart_clickhouse()
|
||||
|
||||
yield cluster
|
||||
|
||||
finally:
|
||||
cluster.shutdown()
|
||||
|
||||
|
||||
def change_config(model_config):
|
||||
node.replace_config(
|
||||
"/etc/clickhouse-server/config.d/models_config.xml",
|
||||
config.format(model_config=model_config),
|
||||
)
|
||||
node.query("SYSTEM RELOAD CONFIG;")
|
||||
|
||||
|
||||
def test(started_cluster):
|
||||
if node.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
# Set config with the path to the first model.
|
||||
change_config("model_config.xml")
|
||||
|
||||
node.query("SELECT modelEvaluate('model1', 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);")
|
||||
|
||||
# Change path to the second model in config.
|
||||
change_config("model_config2.xml")
|
||||
|
||||
# Check that the new model is loaded.
|
||||
node.query("SELECT modelEvaluate('model2', 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);")
|
||||
|
||||
# Check that the old model was unloaded.
|
||||
node.query_and_get_error(
|
||||
"SELECT modelEvaluate('model1', 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
|
||||
)
|
@ -1,4 +0,0 @@
|
||||
<clickhouse>
|
||||
<catboost_dynamic_library_path>/etc/clickhouse-server/model/libcatboostmodel.so</catboost_dynamic_library_path>
|
||||
<models_config>/etc/clickhouse-server/model/model_config.xml</models_config>
|
||||
</clickhouse>
|
Binary file not shown.
Binary file not shown.
@ -1,8 +0,0 @@
|
||||
<models>
|
||||
<model>
|
||||
<type>catboost</type>
|
||||
<name>titanic</name>
|
||||
<path>/etc/clickhouse-server/model/model.bin</path>
|
||||
<lifetime>0</lifetime>
|
||||
</model>
|
||||
</models>
|
@ -1,48 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
from helpers.cluster import ClickHouseCluster
|
||||
|
||||
cluster = ClickHouseCluster(__file__)
|
||||
node = cluster.add_instance(
|
||||
"node", stay_alive=True, main_configs=["config/models_config.xml"]
|
||||
)
|
||||
|
||||
|
||||
def copy_file_to_container(local_path, dist_path, container_id):
|
||||
os.system(
|
||||
"docker cp {local} {cont_id}:{dist}".format(
|
||||
local=local_path, cont_id=container_id, dist=dist_path
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def started_cluster():
|
||||
try:
|
||||
cluster.start()
|
||||
|
||||
copy_file_to_container(
|
||||
os.path.join(SCRIPT_DIR, "model/."),
|
||||
"/etc/clickhouse-server/model",
|
||||
node.docker_id,
|
||||
)
|
||||
node.restart_clickhouse()
|
||||
|
||||
yield cluster
|
||||
|
||||
finally:
|
||||
cluster.shutdown()
|
||||
|
||||
|
||||
def test(started_cluster):
|
||||
if node.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
node.query("select modelEvaluate('titanic', 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);")
|
@ -1,3 +0,0 @@
|
||||
<clickhouse>
|
||||
<catboost_dynamic_library_path>/etc/clickhouse-server/model/libcatboostmodel.so</catboost_dynamic_library_path>
|
||||
</clickhouse>
|
@ -1,3 +0,0 @@
|
||||
<clickhouse>
|
||||
<models_config>/etc/clickhouse-server/model/model_config.xml</models_config>
|
||||
</clickhouse>
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -1,8 +0,0 @@
|
||||
<models>
|
||||
<model>
|
||||
<type>catboost</type>
|
||||
<name>model</name>
|
||||
<path>/etc/clickhouse-server/model/model.cbm</path>
|
||||
<lifetime>0</lifetime>
|
||||
</model>
|
||||
</models>
|
@ -1,132 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
from helpers.cluster import ClickHouseCluster
|
||||
|
||||
cluster = ClickHouseCluster(__file__)
|
||||
node = cluster.add_instance(
|
||||
"node",
|
||||
stay_alive=True,
|
||||
main_configs=["config/models_config.xml", "config/catboost_lib.xml"],
|
||||
)
|
||||
|
||||
|
||||
def copy_file_to_container(local_path, dist_path, container_id):
|
||||
os.system(
|
||||
"docker cp {local} {cont_id}:{dist}".format(
|
||||
local=local_path, cont_id=container_id, dist=dist_path
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def started_cluster():
|
||||
try:
|
||||
cluster.start()
|
||||
|
||||
copy_file_to_container(
|
||||
os.path.join(SCRIPT_DIR, "model/."),
|
||||
"/etc/clickhouse-server/model",
|
||||
node.docker_id,
|
||||
)
|
||||
node.query("CREATE TABLE binary (x UInt64, y UInt64) ENGINE = TinyLog()")
|
||||
node.query("INSERT INTO binary VALUES (1, 1), (1, 0), (0, 1), (0, 0)")
|
||||
|
||||
node.restart_clickhouse()
|
||||
|
||||
yield cluster
|
||||
|
||||
finally:
|
||||
cluster.shutdown()
|
||||
|
||||
|
||||
def test_model_reload(started_cluster):
|
||||
if node.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
node.exec_in_container(
|
||||
["bash", "-c", "rm -f /etc/clickhouse-server/model/model.cbm"]
|
||||
)
|
||||
node.exec_in_container(
|
||||
[
|
||||
"bash",
|
||||
"-c",
|
||||
"ln /etc/clickhouse-server/model/conjunction.cbm /etc/clickhouse-server/model/model.cbm",
|
||||
]
|
||||
)
|
||||
node.query("SYSTEM RELOAD MODEL model")
|
||||
|
||||
result = node.query(
|
||||
"""
|
||||
WITH modelEvaluate('model', toFloat64(x), toFloat64(y)) as prediction, exp(prediction) / (1 + exp(prediction)) as probability
|
||||
SELECT if(probability > 0.5, 1, 0) FROM binary;
|
||||
"""
|
||||
)
|
||||
assert result == "1\n0\n0\n0\n"
|
||||
|
||||
node.exec_in_container(["bash", "-c", "rm /etc/clickhouse-server/model/model.cbm"])
|
||||
node.exec_in_container(
|
||||
[
|
||||
"bash",
|
||||
"-c",
|
||||
"ln /etc/clickhouse-server/model/disjunction.cbm /etc/clickhouse-server/model/model.cbm",
|
||||
]
|
||||
)
|
||||
node.query("SYSTEM RELOAD MODEL model")
|
||||
|
||||
result = node.query(
|
||||
"""
|
||||
WITH modelEvaluate('model', toFloat64(x), toFloat64(y)) as prediction, exp(prediction) / (1 + exp(prediction)) as probability
|
||||
SELECT if(probability > 0.5, 1, 0) FROM binary;
|
||||
"""
|
||||
)
|
||||
assert result == "1\n1\n1\n0\n"
|
||||
|
||||
|
||||
def test_models_reload(started_cluster):
|
||||
if node.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
node.exec_in_container(
|
||||
["bash", "-c", "rm -f /etc/clickhouse-server/model/model.cbm"]
|
||||
)
|
||||
node.exec_in_container(
|
||||
[
|
||||
"bash",
|
||||
"-c",
|
||||
"ln /etc/clickhouse-server/model/conjunction.cbm /etc/clickhouse-server/model/model.cbm",
|
||||
]
|
||||
)
|
||||
node.query("SYSTEM RELOAD MODELS")
|
||||
|
||||
result = node.query(
|
||||
"""
|
||||
WITH modelEvaluate('model', toFloat64(x), toFloat64(y)) as prediction, exp(prediction) / (1 + exp(prediction)) as probability
|
||||
SELECT if(probability > 0.5, 1, 0) FROM binary;
|
||||
"""
|
||||
)
|
||||
assert result == "1\n0\n0\n0\n"
|
||||
|
||||
node.exec_in_container(["bash", "-c", "rm /etc/clickhouse-server/model/model.cbm"])
|
||||
node.exec_in_container(
|
||||
[
|
||||
"bash",
|
||||
"-c",
|
||||
"ln /etc/clickhouse-server/model/disjunction.cbm /etc/clickhouse-server/model/model.cbm",
|
||||
]
|
||||
)
|
||||
node.query("SYSTEM RELOAD MODELS")
|
||||
|
||||
result = node.query(
|
||||
"""
|
||||
WITH modelEvaluate('model', toFloat64(x), toFloat64(y)) as prediction, exp(prediction) / (1 + exp(prediction)) as probability
|
||||
SELECT if(probability > 0.5, 1, 0) FROM binary;
|
||||
"""
|
||||
)
|
||||
assert result == "1\n1\n1\n0\n"
|
@ -16,7 +16,7 @@ function run_selects()
|
||||
{
|
||||
thread_num=$1
|
||||
readarray -t tables_arr < <(${CLICKHOUSE_CLIENT} -q "SELECT database || '.' || name FROM system.tables
|
||||
WHERE database in ('system', 'information_schema', 'INFORMATION_SCHEMA') and name!='zookeeper' and name!='merge_tree_metadata_cache'
|
||||
WHERE database in ('system', 'information_schema', 'INFORMATION_SCHEMA') and name!='zookeeper' and name!='merge_tree_metadata_cache' and name!='models'
|
||||
AND sipHash64(name || toString($RAND)) % $THREADS = $thread_num")
|
||||
|
||||
for t in "${tables_arr[@]}"
|
||||
|
@ -364,18 +364,6 @@ CREATE TABLE system.metrics
|
||||
)
|
||||
ENGINE = SystemMetrics
|
||||
COMMENT 'SYSTEM TABLE is built on the fly.'
|
||||
CREATE TABLE system.models
|
||||
(
|
||||
`name` String,
|
||||
`status` Enum8('NOT_LOADED' = 0, 'LOADED' = 1, 'FAILED' = 2, 'LOADING' = 3, 'FAILED_AND_RELOADING' = 4, 'LOADED_AND_RELOADING' = 5, 'NOT_EXIST' = 6),
|
||||
`origin` String,
|
||||
`type` String,
|
||||
`loading_start_time` DateTime,
|
||||
`loading_duration` Float32,
|
||||
`last_exception` String
|
||||
)
|
||||
ENGINE = SystemModels
|
||||
COMMENT 'SYSTEM TABLE is built on the fly.'
|
||||
CREATE TABLE system.mutations
|
||||
(
|
||||
`database` String,
|
||||
|
@ -45,7 +45,6 @@ show create table macros format TSVRaw;
|
||||
show create table merge_tree_settings format TSVRaw;
|
||||
show create table merges format TSVRaw;
|
||||
show create table metrics format TSVRaw;
|
||||
show create table models format TSVRaw;
|
||||
show create table mutations format TSVRaw;
|
||||
show create table numbers format TSVRaw;
|
||||
show create table numbers_mt format TSVRaw;
|
||||
|
@ -1,2 +0,0 @@
|
||||
-- This model does not exist:
|
||||
SELECT modelEvaluate('hello', 1, 2, 3); -- { serverError 36 }
|
@ -192,6 +192,7 @@ caseWithExpr
|
||||
caseWithExpression
|
||||
caseWithoutExpr
|
||||
caseWithoutExpression
|
||||
catboostEvaluate
|
||||
cbrt
|
||||
ceil
|
||||
char
|
||||
@ -475,7 +476,6 @@ min2
|
||||
minSampleSizeContinous
|
||||
minSampleSizeConversion
|
||||
minus
|
||||
modelEvaluate
|
||||
modulo
|
||||
moduloLegacy
|
||||
moduloOrZero
|
||||
|
Loading…
Reference in New Issue
Block a user