Merge pull request #40897 from ClickHouse/catboost-bridge-resurrected

Move CatBoost evaluation into clickhouse-library-bridge
Robert Schulze 2022-09-16 13:12:09 +02:00 committed by GitHub
commit b32b02d844
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
70 changed files with 1785 additions and 1720 deletions

4
.gitignore vendored
@@ -58,6 +58,10 @@ cmake_install.cmake
CTestTestfile.cmake
*.a
*.o
*.so
*.dll
*.lib
*.dylib
cmake-build-*
# Python cache

@@ -1823,6 +1823,36 @@ Result:
Evaluate external model.
Accepts a model name and model arguments. Returns Float64.
## catboostEvaluate(path_to_model, feature_1, feature_2, …, feature_n)
Evaluate an external catboost model. [CatBoost](https://catboost.ai) is an open-source gradient boosting library developed by Yandex for machine learning.
Accepts a path to a catboost model and model arguments (features). Returns Float64.
``` sql
SELECT feat_1, ..., feat_n, catboostEvaluate('/path/to/model.bin', feat_1, ..., feat_n) AS prediction
FROM data_table
```
**Prerequisites**
1. Build the catboost evaluation library
Before evaluating catboost models, the `libcatboostmodel.<so|dylib>` library must be made available. See the [CatBoost documentation](https://catboost.ai/docs/concepts/c-plus-plus-api_dynamic-c-pluplus-wrapper.html) for how to compile it.
Next, specify the path to `libcatboostmodel.<so|dylib>` in the ClickHouse configuration:
``` xml
<clickhouse>
...
<catboost_lib_path>/path/to/libcatboostmodel.so</catboost_lib_path>
...
</clickhouse>
```
2. Train a catboost model using libcatboost
See [Training and applying models](https://catboost.ai/docs/features/training.html#training) for how to train catboost models from a training data set.
## throwIf(x\[, message\[, error_code\]\])
Throw an exception if the argument is non-zero.

@@ -30,7 +30,12 @@ SELECT name, status FROM system.dictionaries;
## RELOAD MODELS
Reloads all [CatBoost](../../guides/developer/apply-catboost-model.md) models if the configuration was updated without restarting the server.
:::note
This statement and `SYSTEM RELOAD MODEL` merely unload catboost models from the clickhouse-library-bridge. The function `catboostEvaluate()`
loads a model upon first access if it is not loaded yet.
:::
Unloads all CatBoost models.
**Syntax**
@@ -40,12 +45,12 @@ SYSTEM RELOAD MODELS [ON CLUSTER cluster_name]
## RELOAD MODEL
Completely reloads a CatBoost model `model_name` if the configuration was updated without restarting the server.
Unloads a CatBoost model at `model_path`.
**Syntax**
```sql
SYSTEM RELOAD MODEL [ON CLUSTER cluster_name] <model_name>
SYSTEM RELOAD MODEL [ON CLUSTER cluster_name] <model_path>
```
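The note above describes unload-then-lazy-load behaviour: `SYSTEM RELOAD MODEL(S)` only drops models from the bridge, and the next `catboostEvaluate()` call loads the model again. A minimal sketch of that pattern follows. It is not the bridge's actual implementation; `Model`, `ModelCache` and the loader callback are hypothetical names introduced only for illustration.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical stand-in for a loaded model.
struct Model
{
    std::string path;
};

// Hypothetical cache illustrating "unload now, load again on first access".
class ModelCache
{
public:
    using Loader = std::function<std::shared_ptr<Model>(const std::string &)>;

    explicit ModelCache(Loader loader_) : loader(std::move(loader_)) {}

    // Returns the cached model, loading it first if it is not present.
    std::shared_ptr<Model> getOrLoad(const std::string & path)
    {
        std::lock_guard<std::mutex> lock(mutex);
        auto it = models.find(path);
        if (it == models.end())
            it = models.emplace(path, loader(path)).first;
        return it->second;
    }

    // Analogue of SYSTEM RELOAD MODEL <model_path>: drop one entry.
    void unload(const std::string & path)
    {
        std::lock_guard<std::mutex> lock(mutex);
        models.erase(path);
    }

    // Analogue of SYSTEM RELOAD MODELS: drop every entry.
    void unloadAll()
    {
        std::lock_guard<std::mutex> lock(mutex);
        models.clear();
    }

private:
    Loader loader;
    std::mutex mutex;
    std::unordered_map<std::string, std::shared_ptr<Model>> models;
};
```

With this shape, a "reload" never touches the model file itself; the cost of loading is paid by the first evaluation after the unload.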
## RELOAD FUNCTIONS

@@ -155,7 +155,6 @@ getting_started/index.md getting-started/index.md
getting_started/install.md getting-started/install.md
getting_started/playground.md getting-started/playground.md
getting_started/tutorial.md getting-started/tutorial.md
guides/apply_catboost_model.md guides/apply-catboost-model.md
images/column_oriented.gif images/column-oriented.gif
images/row_oriented.gif images/row-oriented.gif
interfaces/http_interface.md interfaces/http.md

@@ -1,241 +0,0 @@
---
slug: /ru/guides/apply-catboost-model
sidebar_position: 41
sidebar_label: "Applying a CatBoost Model in ClickHouse"
---
# Applying a CatBoost Model in ClickHouse {#applying-catboost-model-in-clickhouse}
[CatBoost](https://catboost.ai) is an open-source machine learning library developed by [Yandex](https://yandex.ru/company/) that uses gradient boosting.
This guide shows how to apply pre-trained models in ClickHouse by running model inference from SQL.
To apply a CatBoost model in ClickHouse:
1. [Create a table](#create-table).
2. [Insert the data into the table](#insert-data-to-table).
3. [Integrate CatBoost into ClickHouse](#integrate-catboost-into-clickhouse) (optional step).
4. [Run model inference from SQL](#run-model-inference).
For more information about training CatBoost models, see [Training and applying models](https://catboost.ai/docs/features/training.html#training).
You can reload CatBoost models if their configuration has been updated, without restarting the server. To do this, use the [RELOAD MODEL](../sql-reference/statements/system.md#query_language-system-reload-model) and [RELOAD MODELS](../sql-reference/statements/system.md#query_language-system-reload-models) system queries.
## Before You Start {#prerequisites}
If you do not have [Docker](https://docs.docker.com/install/) yet, install it.
:::note "Note"
[Docker](https://www.docker.com) is a software platform for creating containers that isolate the CatBoost and ClickHouse installation from the rest of the system.
:::
Before applying a CatBoost model:
**1.** Pull the [Docker image](https://hub.docker.com/r/yandex/tutorial-catboost-clickhouse) from the registry:
``` bash
$ docker pull yandex/tutorial-catboost-clickhouse
```
This Docker image contains everything needed to run CatBoost and ClickHouse: code, runtime, libraries, environment variables, and configuration files.
**2.** Make sure the Docker image was pulled successfully:
``` bash
$ docker image ls
REPOSITORY TAG IMAGE ID CREATED SIZE
yandex/tutorial-catboost-clickhouse latest 622e4d17945b 22 hours ago 1.37GB
```
**3.** Start a Docker container based on this image:
``` bash
$ docker run -it -p 8888:8888 yandex/tutorial-catboost-clickhouse
```
## 1. Create a Table {#create-table}
To create a ClickHouse table for the training sample:
**1.** Start the ClickHouse client:
``` bash
$ clickhouse client
```
:::note "Note"
The ClickHouse server is already running inside the Docker container.
:::
**2.** Create the table in ClickHouse using the following command:
``` sql
:) CREATE TABLE amazon_train
(
date Date MATERIALIZED today(),
ACTION UInt8,
RESOURCE UInt32,
MGR_ID UInt32,
ROLE_ROLLUP_1 UInt32,
ROLE_ROLLUP_2 UInt32,
ROLE_DEPTNAME UInt32,
ROLE_TITLE UInt32,
ROLE_FAMILY_DESC UInt32,
ROLE_FAMILY UInt32,
ROLE_CODE UInt32
)
ENGINE = MergeTree ORDER BY date
```
**3.** Exit the ClickHouse client:
``` sql
:) exit
```
## 2. Insert the Data into the Table {#insert-data-to-table}
To insert the data:
**1.** Run the following command:
``` bash
$ clickhouse client --host 127.0.0.1 --query 'INSERT INTO amazon_train FORMAT CSVWithNames' < ~/amazon/train.csv
```
**2.** Start the ClickHouse client:
``` bash
$ clickhouse client
```
**3.** Make sure the data was loaded successfully:
``` sql
:) SELECT count() FROM amazon_train
SELECT count()
FROM amazon_train
+-count()-+
| 65538 |
+---------+
```
## 3. Integrate CatBoost into ClickHouse {#integrate-catboost-into-clickhouse}
:::note "Note"
**Optional step.** The Docker image contains everything needed to run CatBoost and ClickHouse.
:::
To integrate CatBoost into ClickHouse:
**1.** Build the evaluation library.
The fastest way to evaluate a CatBoost model is to compile the `libcatboostmodel.<so|dll|dylib>` library. For details on how to build the library, see the [CatBoost documentation](https://catboost.ai/docs/concepts/c-plus-plus-api_dynamic-c-pluplus-wrapper.html).
**2.** Create a new directory anywhere, with any name, for example `data`, and put the built library in it. The Docker image already contains the library `data/libcatboostmodel.so`.
**3.** Create a new directory for the model configuration anywhere, with any name, for example `models`.
**4.** Create a model configuration file with any name, for example `models/amazon_model.xml`.
**5.** Describe the model configuration:
``` xml
<models>
<model>
<!-- Model type. Currently ClickHouse provides only the catboost model. -->
<type>catboost</type>
<!-- Model name. -->
<name>amazon</name>
<!-- Path to the trained model. -->
<path>/home/catboost/tutorial/catboost_model.bin</path>
<!-- Update interval. -->
<lifetime>0</lifetime>
</model>
</models>
```
**6.** Add the path to CatBoost and the model configuration to the ClickHouse configuration:
``` xml
<!-- File etc/clickhouse-server/config.d/models_config.xml. -->
<catboost_dynamic_library_path>/home/catboost/data/libcatboostmodel.so</catboost_dynamic_library_path>
<models_config>/home/catboost/models/*_model.xml</models_config>
```
:::note "Note"
You can change the path to the CatBoost model configuration later without restarting the server.
:::
## 4. Run Model Inference from SQL {#run-model-inference}
To test the model, run the ClickHouse client `$ clickhouse client`.
Make sure the model works:
``` sql
:) SELECT
modelEvaluate('amazon',
RESOURCE,
MGR_ID,
ROLE_ROLLUP_1,
ROLE_ROLLUP_2,
ROLE_DEPTNAME,
ROLE_TITLE,
ROLE_FAMILY_DESC,
ROLE_FAMILY,
ROLE_CODE) > 0 AS prediction,
ACTION AS target
FROM amazon_train
LIMIT 10
```
:::note "Note"
The [modelEvaluate](../sql-reference/functions/other-functions.md#function-modelevaluate) function returns tuples with per-class raw predictions for multiclass models.
:::
Predict the probability:
``` sql
:) SELECT
modelEvaluate('amazon',
RESOURCE,
MGR_ID,
ROLE_ROLLUP_1,
ROLE_ROLLUP_2,
ROLE_DEPTNAME,
ROLE_TITLE,
ROLE_FAMILY_DESC,
ROLE_FAMILY,
ROLE_CODE) AS prediction,
1. / (1 + exp(-prediction)) AS probability,
ACTION AS target
FROM amazon_train
LIMIT 10
```
:::note "Note"
For more information, see the [exp()](../sql-reference/functions/math-functions.md) function.
:::
Calculate the logistic loss (LogLoss) on the whole sample:
``` sql
:) SELECT -avg(tg * log(prob) + (1 - tg) * log(1 - prob)) AS logloss
FROM
(
SELECT
modelEvaluate('amazon',
RESOURCE,
MGR_ID,
ROLE_ROLLUP_1,
ROLE_ROLLUP_2,
ROLE_DEPTNAME,
ROLE_TITLE,
ROLE_FAMILY_DESC,
ROLE_FAMILY,
ROLE_CODE) AS prediction,
1. / (1. + exp(-prediction)) AS prob,
ACTION AS tg
FROM amazon_train
)
```
:::note "Note"
For more information, see the [avg()](../sql-reference/aggregate-functions/reference/avg.md#agg_function-avg) and [log()](../sql-reference/functions/math-functions.md) functions.
:::
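The last two queries of this guide apply the logistic function `1 / (1 + exp(-prediction))` to the raw model output and then average the logistic loss over all rows. The same arithmetic can be sketched outside SQL as follows; the function names `sigmoid` and `logLoss` are illustrative assumptions and are not part of ClickHouse or CatBoost.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Logistic function: maps a raw model score to a probability in (0, 1).
inline double sigmoid(double raw_prediction)
{
    return 1.0 / (1.0 + std::exp(-raw_prediction));
}

// Mean logistic loss, mirroring
// -avg(tg * log(prob) + (1 - tg) * log(1 - prob)) from the query above.
// Targets are expected to be 0 or 1.
double logLoss(const std::vector<double> & raw_predictions, const std::vector<int> & targets)
{
    assert(raw_predictions.size() == targets.size() && !targets.empty());
    double sum = 0.0;
    for (std::size_t i = 0; i < targets.size(); ++i)
    {
        const double prob = sigmoid(raw_predictions[i]);
        sum += targets[i] * std::log(prob) + (1 - targets[i]) * std::log(1.0 - prob);
    }
    return -sum / static_cast<double>(targets.size());
}
```

A raw score of 0 corresponds to a probability of 0.5, so a model that always outputs 0 has a LogLoss of `ln 2` regardless of the targets.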

@@ -7,5 +7,3 @@ sidebar_label: "Guides"
# Guides {#rukovodstva}
Detailed step-by-step instructions that help you solve various tasks using ClickHouse.
- [Applying a CatBoost Model in ClickHouse](apply-catboost-model.md)

@@ -29,7 +29,12 @@ SELECT name, status FROM system.dictionaries;
## RELOAD MODELS {#query_language-system-reload-models}
Reloads all [CatBoost](../../guides/apply-catboost-model.md#applying-catboost-model-in-clickhouse) models if their configuration was updated, without restarting the server.
:::note
This statement and `SYSTEM RELOAD MODEL` merely unload catboost models from the clickhouse-library-bridge. The function `catboostEvaluate()`
loads a model upon first access if it is not loaded yet.
:::
Unloads all CatBoost models.
**Syntax**
@@ -39,12 +44,12 @@ SYSTEM RELOAD MODELS
## RELOAD MODEL {#query_language-system-reload-model}
Completely reloads the [CatBoost](../../guides/apply-catboost-model.md#applying-catboost-model-in-clickhouse) model `model_name` if its configuration was updated, without restarting the server.
Unloads a CatBoost model at `model_path`.
**Syntax**
```sql
SYSTEM RELOAD MODEL <model_name>
SYSTEM RELOAD MODEL <model_path>
```
## RELOAD FUNCTIONS {#query_language-system-reload-functions}

@@ -1,244 +0,0 @@
---
slug: /zh/guides/apply-catboost-model
sidebar_position: 41
sidebar_label: "Applying CatBoost Models"
---
# Applying a CatBoost Model in ClickHouse {#applying-catboost-model-in-clickhouse}
[CatBoost](https://catboost.ai) is a free and open-source machine learning library developed by [Yandex](https://yandex.com/company/).
This guide shows how to use SQL statements to call a pre-trained model already stored in ClickHouse to make predictions.
To apply a CatBoost model in ClickHouse, follow these steps:
1. [Create a table](#create-table).
2. [Insert the data into the table](#insert-data-to-table).
3. [Integrate CatBoost into ClickHouse](#integrate-catboost-into-clickhouse) (optional).
4. [Run model inference from SQL](#run-model-inference).
For details on training CatBoost models, see [Training and applying models](https://catboost.ai/docs/features/training.html#training).
You can reload CatBoost models with the [RELOAD MODEL](https://clickhouse.com/docs/en/sql-reference/statements/system/#query_language-system-reload-model) and [RELOAD MODELS](https://clickhouse.com/docs/en/sql-reference/statements/system/#query_language-system-reload-models) statements.
## Prerequisites {#prerequisites}
Install [Docker](https://docs.docker.com/install/) first.
!!! note "Note"
[Docker](https://www.docker.com) is a software platform that lets you create containers that are isolated from the existing system and bundle CatBoost and ClickHouse.
Before applying a CatBoost model:
**1.** Pull the example Docker image (https://hub.docker.com/r/yandex/tutorial-catboost-clickhouse) from the registry:
``` bash
$ docker pull yandex/tutorial-catboost-clickhouse
```
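This example Docker image contains everything needed to run CatBoost and ClickHouse: code, runtime, libraries, environment variables, and configuration files.
**2.** Make sure the Docker image was pulled successfully: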
``` bash
$ docker image ls
REPOSITORY TAG IMAGE ID CREATED SIZE
yandex/tutorial-catboost-clickhouse latest 622e4d17945b 22 hours ago 1.37GB
```
**3.** Start a Docker container based on this image:
``` bash
$ docker run -it -p 8888:8888 yandex/tutorial-catboost-clickhouse
```
## 1. Create a Table {#create-table}
Create a ClickHouse table for the training sample:
**1.** Start the ClickHouse console client in interactive mode:
``` bash
$ clickhouse client
```
!!! note "Note"
The ClickHouse server is already running inside the Docker container.
**2.** Create the table using the following command:
``` sql
:) CREATE TABLE amazon_train
(
date Date MATERIALIZED today(),
ACTION UInt8,
RESOURCE UInt32,
MGR_ID UInt32,
ROLE_ROLLUP_1 UInt32,
ROLE_ROLLUP_2 UInt32,
ROLE_DEPTNAME UInt32,
ROLE_TITLE UInt32,
ROLE_FAMILY_DESC UInt32,
ROLE_FAMILY UInt32,
ROLE_CODE UInt32
)
ENGINE = MergeTree ORDER BY date
```
**3.** Exit the ClickHouse console client:
``` sql
:) exit
```
## 2. Insert the Data into the Table {#insert-data-to-table}
Insert the data:
**1.** Run the following command:
``` bash
$ clickhouse client --host 127.0.0.1 --query 'INSERT INTO amazon_train FORMAT CSVWithNames' < ~/amazon/train.csv
```
**2.** Start the ClickHouse console client in interactive mode:
``` bash
$ clickhouse client
```
**3.** Make sure the data was uploaded:
``` sql
:) SELECT count() FROM amazon_train
SELECT count()
FROM amazon_train
+-count()-+
| 65538 |
+-------+
```
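## 3. Integrate CatBoost into ClickHouse {#integrate-catboost-into-clickhouse}
!!! note "Note"
**Optional.** The example Docker image already contains everything needed to run CatBoost and ClickHouse.
To integrate CatBoost into ClickHouse, follow these steps:
**1.** Build the evaluation library.
The fastest way to evaluate a CatBoost model is to compile the `libcatboostmodel.<so|dll|dylib>` library.
For details on how to build the library, see the [CatBoost documentation](https://catboost.ai/docs/concepts/c-plus-plus-api_dynamic-c-pluplus-wrapper.html).
**2.** Create a new directory (anywhere, with any name), for example `data`, and put the built library in it. The example Docker image already contains the library `data/libcatboostmodel.so`.
**3.** Create a new directory for the model configuration, for example `models`.
**4.** Create a model configuration file, for example `models/amazon_model.xml`.
**5.** Edit the model configuration: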
``` xml
<models>
<model>
<!-- Model type. Now catboost only. -->
<type>catboost</type>
<!-- Model name. -->
<name>amazon</name>
<!-- Path to trained model. -->
<path>/home/catboost/tutorial/catboost_model.bin</path>
<!-- Update interval. -->
<lifetime>0</lifetime>
</model>
</models>
```
**6.** Add the path to the CatBoost library and the model configuration to the ClickHouse configuration:
``` xml
<!-- File etc/clickhouse-server/config.d/models_config.xml. -->
<catboost_dynamic_library_path>/home/catboost/data/libcatboostmodel.so</catboost_dynamic_library_path>
<models_config>/home/catboost/models/*_model.xml</models_config>
```
## 4. Run Model Inference from SQL {#run-model-inference}
To test the model, run the ClickHouse client `$ clickhouse client`.
Make sure the model works:
``` sql
:) SELECT
modelEvaluate('amazon',
RESOURCE,
MGR_ID,
ROLE_ROLLUP_1,
ROLE_ROLLUP_2,
ROLE_DEPTNAME,
ROLE_TITLE,
ROLE_FAMILY_DESC,
ROLE_FAMILY,
ROLE_CODE) > 0 AS prediction,
ACTION AS target
FROM amazon_train
LIMIT 10
```
!!! note "Note"
For multiclass models, the [modelEvaluate](../sql-reference/functions/other-functions.md#function-modelevaluate) function returns a tuple with the raw prediction for each class.
Run the prediction:
``` sql
:) SELECT
modelEvaluate('amazon',
RESOURCE,
MGR_ID,
ROLE_ROLLUP_1,
ROLE_ROLLUP_2,
ROLE_DEPTNAME,
ROLE_TITLE,
ROLE_FAMILY_DESC,
ROLE_FAMILY,
ROLE_CODE) AS prediction,
1. / (1 + exp(-prediction)) AS probability,
ACTION AS target
FROM amazon_train
LIMIT 10
```
!!! note "Note"
See the description of the [exp()](../sql-reference/functions/math-functions.md) function.
Calculate LogLoss on the sample:
``` sql
:) SELECT -avg(tg * log(prob) + (1 - tg) * log(1 - prob)) AS logloss
FROM
(
SELECT
modelEvaluate('amazon',
RESOURCE,
MGR_ID,
ROLE_ROLLUP_1,
ROLE_ROLLUP_2,
ROLE_DEPTNAME,
ROLE_TITLE,
ROLE_FAMILY_DESC,
ROLE_FAMILY,
ROLE_CODE) AS prediction,
1. / (1. + exp(-prediction)) AS prob,
ACTION AS tg
FROM amazon_train
)
```
!!! note "Note"
See the descriptions of the [avg()](../sql-reference/aggregate-functions/reference/avg.md#agg_function-avg) and [log()](../sql-reference/functions/math-functions.md) functions.
[Original article](https://clickhouse.com/docs/en/guides/apply_catboost_model/) <!--hide-->

@@ -9,6 +9,5 @@ sidebar_label: ClickHouse Guides
Detailed instructions on how to solve various tasks using ClickHouse:
- [Tutorial on simple cluster set-up](../getting-started/tutorial.md)
- [Applying a CatBoost Model in ClickHouse](apply-catboost-model.md)
[Original article](https://clickhouse.com/docs/en/guides/) <!--hide-->

@@ -54,7 +54,7 @@ else ()
endif ()
if (NOT USE_MUSL)
option (ENABLE_CLICKHOUSE_LIBRARY_BRIDGE "HTTP-server working like a proxy to Library dictionary source" ${ENABLE_CLICKHOUSE_ALL})
option (ENABLE_CLICKHOUSE_LIBRARY_BRIDGE "HTTP-server working like a proxy to external dynamically loaded libraries" ${ENABLE_CLICKHOUSE_ALL})
endif ()
# https://presentations.clickhouse.com/matemarketing_2020/

@@ -1,6 +1,8 @@
include(${ClickHouse_SOURCE_DIR}/cmake/split_debug_symbols.cmake)
set (CLICKHOUSE_LIBRARY_BRIDGE_SOURCES
CatBoostLibraryHandler.cpp
CatBoostLibraryHandlerFactory.cpp
ExternalDictionaryLibraryAPI.cpp
ExternalDictionaryLibraryHandler.cpp
ExternalDictionaryLibraryHandlerFactory.cpp

@@ -0,0 +1,49 @@
#pragma once
#include <cstdint>
#include <cstddef>
// Function pointer typedefs and names of libcatboost.so functions used by ClickHouse
struct CatBoostLibraryAPI
{
using ModelCalcerHandle = void;
using ModelCalcerCreateFunc = ModelCalcerHandle * (*)();
static constexpr const char * ModelCalcerCreateName = "ModelCalcerCreate";
using ModelCalcerDeleteFunc = void (*)(ModelCalcerHandle *);
static constexpr const char * ModelCalcerDeleteName = "ModelCalcerDelete";
using GetErrorStringFunc = const char * (*)();
static constexpr const char * GetErrorStringName = "GetErrorString";
using LoadFullModelFromFileFunc = bool (*)(ModelCalcerHandle *, const char *);
static constexpr const char * LoadFullModelFromFileName = "LoadFullModelFromFile";
using CalcModelPredictionFlatFunc = bool (*)(ModelCalcerHandle *, size_t, const float **, size_t, double *, size_t);
static constexpr const char * CalcModelPredictionFlatName = "CalcModelPredictionFlat";
using CalcModelPredictionFunc = bool (*)(ModelCalcerHandle *, size_t, const float **, size_t, const char ***, size_t, double *, size_t);
static constexpr const char * CalcModelPredictionName = "CalcModelPrediction";
using CalcModelPredictionWithHashedCatFeaturesFunc = bool (*)(ModelCalcerHandle *, size_t, const float **, size_t, const int **, size_t, double *, size_t);
static constexpr const char * CalcModelPredictionWithHashedCatFeaturesName = "CalcModelPredictionWithHashedCatFeatures";
using GetStringCatFeatureHashFunc = int (*)(const char *, size_t);
static constexpr const char * GetStringCatFeatureHashName = "GetStringCatFeatureHash";
using GetIntegerCatFeatureHashFunc = int (*)(uint64_t);
static constexpr const char * GetIntegerCatFeatureHashName = "GetIntegerCatFeatureHash";
using GetFloatFeaturesCountFunc = size_t (*)(ModelCalcerHandle *);
static constexpr const char * GetFloatFeaturesCountName = "GetFloatFeaturesCount";
using GetCatFeaturesCountFunc = size_t (*)(ModelCalcerHandle *);
static constexpr const char * GetCatFeaturesCountName = "GetCatFeaturesCount";
using GetTreeCountFunc = size_t (*)(ModelCalcerHandle *);
static constexpr const char * GetTreeCountName = "GetTreeCount";
using GetDimensionsCountFunc = size_t (*)(ModelCalcerHandle *);
static constexpr const char * GetDimensionsCountName = "GetDimensionsCount";
};
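The struct above pairs each function-pointer type with the exported symbol name it is looked up by. The handler added in this commit resolves most symbols with `lib.get<...>()` but uses `lib.tryGet<...>()` for `GetTreeCount` and `GetDimensionsCount`, then falls back to a count of 1 when the pointer is null. A minimal sketch of that required-versus-optional lookup follows. `FakeLibrary` is a hypothetical stand-in that uses a map instead of a real shared library, and `dimensionsOrDefault` is an illustrative helper, not ClickHouse code.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for a loaded shared library: maps symbol names to
// function addresses. A real loader would query the dynamic linker instead.
class FakeLibrary
{
public:
    using AnyFunc = void (*)();

    template <typename Func>
    void add(const std::string & name, Func func)
    {
        symbols[name] = reinterpret_cast<AnyFunc>(func);
    }

    // Optional lookup: returns nullptr if the symbol is missing.
    template <typename Func>
    Func tryGet(const std::string & name) const
    {
        auto it = symbols.find(name);
        return it == symbols.end() ? nullptr : reinterpret_cast<Func>(it->second);
    }

    // Required lookup: throws if the symbol is missing.
    template <typename Func>
    Func get(const std::string & name) const
    {
        if (Func func = tryGet<Func>(name))
            return func;
        throw std::runtime_error("Cannot find symbol " + name);
    }

private:
    std::map<std::string, AnyFunc> symbols;
};

using GetDimensionsCountFunc = std::size_t (*)(void *);

// Illustrative implementation of the optional symbol.
std::size_t fakeGetDimensionsCount(void *) { return 3; }

// Use the optional function if the library exports it, otherwise assume 1.
std::size_t dimensionsOrDefault(const FakeLibrary & lib, void * handle)
{
    auto func = lib.tryGet<GetDimensionsCountFunc>("GetDimensionsCount");
    return func ? func(handle) : 1;
}
```

Treating a symbol as optional lets the same binary work with library builds that do not export it, at the price of a null check at every call site.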

@@ -0,0 +1,389 @@
#include "CatBoostLibraryHandler.h"
#include <Columns/ColumnTuple.h>
#include <Common/FieldVisitorConvertToNumber.h>
namespace DB
{
namespace ErrorCodes
{
extern const int BAD_ARGUMENTS;
extern const int CANNOT_APPLY_CATBOOST_MODEL;
extern const int CANNOT_LOAD_CATBOOST_MODEL;
extern const int LOGICAL_ERROR;
}
CatBoostLibraryHandler::APIHolder::APIHolder(SharedLibrary & lib)
{
ModelCalcerCreate = lib.get<CatBoostLibraryAPI::ModelCalcerCreateFunc>(CatBoostLibraryAPI::ModelCalcerCreateName);
ModelCalcerDelete = lib.get<CatBoostLibraryAPI::ModelCalcerDeleteFunc>(CatBoostLibraryAPI::ModelCalcerDeleteName);
GetErrorString = lib.get<CatBoostLibraryAPI::GetErrorStringFunc>(CatBoostLibraryAPI::GetErrorStringName);
LoadFullModelFromFile = lib.get<CatBoostLibraryAPI::LoadFullModelFromFileFunc>(CatBoostLibraryAPI::LoadFullModelFromFileName);
CalcModelPredictionFlat = lib.get<CatBoostLibraryAPI::CalcModelPredictionFlatFunc>(CatBoostLibraryAPI::CalcModelPredictionFlatName);
CalcModelPrediction = lib.get<CatBoostLibraryAPI::CalcModelPredictionFunc>(CatBoostLibraryAPI::CalcModelPredictionName);
CalcModelPredictionWithHashedCatFeatures = lib.get<CatBoostLibraryAPI::CalcModelPredictionWithHashedCatFeaturesFunc>(CatBoostLibraryAPI::CalcModelPredictionWithHashedCatFeaturesName);
GetStringCatFeatureHash = lib.get<CatBoostLibraryAPI::GetStringCatFeatureHashFunc>(CatBoostLibraryAPI::GetStringCatFeatureHashName);
GetIntegerCatFeatureHash = lib.get<CatBoostLibraryAPI::GetIntegerCatFeatureHashFunc>(CatBoostLibraryAPI::GetIntegerCatFeatureHashName);
GetFloatFeaturesCount = lib.get<CatBoostLibraryAPI::GetFloatFeaturesCountFunc>(CatBoostLibraryAPI::GetFloatFeaturesCountName);
GetCatFeaturesCount = lib.get<CatBoostLibraryAPI::GetCatFeaturesCountFunc>(CatBoostLibraryAPI::GetCatFeaturesCountName);
GetTreeCount = lib.tryGet<CatBoostLibraryAPI::GetTreeCountFunc>(CatBoostLibraryAPI::GetTreeCountName);
GetDimensionsCount = lib.tryGet<CatBoostLibraryAPI::GetDimensionsCountFunc>(CatBoostLibraryAPI::GetDimensionsCountName);
}
CatBoostLibraryHandler::CatBoostLibraryHandler(
const std::string & library_path,
const std::string & model_path)
: loading_start_time(std::chrono::system_clock::now())
, library(std::make_shared<SharedLibrary>(library_path))
, api(*library)
{
model_calcer_handle = api.ModelCalcerCreate();
if (!api.LoadFullModelFromFile(model_calcer_handle, model_path.c_str()))
{
throw Exception(ErrorCodes::CANNOT_LOAD_CATBOOST_MODEL,
"Cannot load CatBoost model: {}", api.GetErrorString());
}
float_features_count = api.GetFloatFeaturesCount(model_calcer_handle);
cat_features_count = api.GetCatFeaturesCount(model_calcer_handle);
tree_count = 1;
if (api.GetDimensionsCount)
tree_count = api.GetDimensionsCount(model_calcer_handle);
loading_duration = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now() - loading_start_time);
}
CatBoostLibraryHandler::~CatBoostLibraryHandler()
{
api.ModelCalcerDelete(model_calcer_handle);
}
std::chrono::system_clock::time_point CatBoostLibraryHandler::getLoadingStartTime() const
{
return loading_start_time;
}
std::chrono::milliseconds CatBoostLibraryHandler::getLoadingDuration() const
{
return loading_duration;
}
namespace
{
/// Buffer should be allocated with features_count * column->size() elements.
/// Place column elements in positions buffer[0], buffer[features_count], ... , buffer[(size - 1) * features_count]
template <typename T>
void placeColumnAsNumber(const IColumn * column, T * buffer, size_t features_count)
{
size_t size = column->size();
FieldVisitorConvertToNumber<T> visitor;
for (size_t i = 0; i < size; ++i)
{
/// TODO: Replace with column visitor.
Field field;
column->get(i, field);
*buffer = applyVisitor(visitor, field);
buffer += features_count;
}
}
/// Buffer should be allocated with features_count * column->size() elements.
/// Place string pointers in positions buffer[0], buffer[features_count], ... , buffer[(size - 1) * features_count]
void placeStringColumn(const ColumnString & column, const char ** buffer, size_t features_count)
{
size_t size = column.size();
for (size_t i = 0; i < size; ++i)
{
*buffer = const_cast<char *>(column.getDataAtWithTerminatingZero(i).data);
buffer += features_count;
}
}
/// Buffer should be allocated with features_count * column->size() elements.
/// Place string pointers in positions buffer[0], buffer[features_count], ... , buffer[(size - 1) * features_count]
/// Returns PODArray which holds the data (because ColumnFixedString doesn't store a terminating zero).
PODArray<char> placeFixedStringColumn(const ColumnFixedString & column, const char ** buffer, size_t features_count)
{
size_t size = column.size();
size_t str_size = column.getN();
PODArray<char> data(size * (str_size + 1));
char * data_ptr = data.data();
for (size_t i = 0; i < size; ++i)
{
auto ref = column.getDataAt(i);
memcpy(data_ptr, ref.data, ref.size);
data_ptr[ref.size] = 0;
*buffer = data_ptr;
data_ptr += ref.size + 1;
buffer += features_count;
}
return data;
}
/// Place columns into buffer, returns the column which holds the placed data. Buffer should contain column->size() values.
template <typename T>
ColumnPtr placeNumericColumns(const ColumnRawPtrs & columns, size_t offset, size_t size, const T** buffer)
{
if (size == 0)
return nullptr;
size_t column_size = columns[offset]->size();
auto data_column = ColumnVector<T>::create(size * column_size);
T * data = data_column->getData().data();
for (size_t i = 0; i < size; ++i)
{
const auto * column = columns[offset + i];
if (column->isNumeric())
placeColumnAsNumber(column, data + i, size);
}
for (size_t i = 0; i < column_size; ++i)
{
*buffer = data;
++buffer;
data += size;
}
return data_column;
}
/// Place columns into buffer, returns the data which was used for fixed string columns.
/// Buffer should contain column->size() values, each value contains size strings.
std::vector<PODArray<char>> placeStringColumns(const ColumnRawPtrs & columns, size_t offset, size_t size, const char ** buffer)
{
if (size == 0)
return {};
std::vector<PODArray<char>> data;
for (size_t i = 0; i < size; ++i)
{
const auto * column = columns[offset + i];
if (const auto * column_string = typeid_cast<const ColumnString *>(column))
placeStringColumn(*column_string, buffer + i, size);
else if (const auto * column_fixed_string = typeid_cast<const ColumnFixedString *>(column))
data.push_back(placeFixedStringColumn(*column_fixed_string, buffer + i, size));
else
throw Exception("Cannot place string column.", ErrorCodes::LOGICAL_ERROR);
}
return data;
}
/// buffer[column_size * cat_features_count] -> char * => cat_features[column_size][cat_features_count] -> char *
void fillCatFeaturesBuffer(
const char *** cat_features, const char ** buffer,
size_t column_size, size_t cat_features_count)
{
for (size_t i = 0; i < column_size; ++i)
{
*cat_features = buffer;
++cat_features;
buffer += cat_features_count;
}
}
/// Calc hash for string cat feature at position ps.
template <typename Column>
void calcStringHashes(const Column * column, size_t ps, const int ** buffer, const CatBoostLibraryHandler::APIHolder & api)
{
size_t column_size = column->size();
for (size_t j = 0; j < column_size; ++j)
{
auto ref = column->getDataAt(j);
const_cast<int *>(*buffer)[ps] = api.GetStringCatFeatureHash(ref.data, ref.size);
++buffer;
}
}
/// Calc hash for int cat feature at position ps. Buffer at position ps should contain unhashed values.
void calcIntHashes(size_t column_size, size_t ps, const int ** buffer, const CatBoostLibraryHandler::APIHolder & api)
{
for (size_t j = 0; j < column_size; ++j)
{
const_cast<int *>(*buffer)[ps] = api.GetIntegerCatFeatureHash((*buffer)[ps]);
++buffer;
}
}
/// buffer contains column->size() rows and size columns.
/// For int cat features calc hash inplace.
/// For string cat features calc hash from column rows.
void calcHashes(const ColumnRawPtrs & columns, size_t offset, size_t size, const int ** buffer, const CatBoostLibraryHandler::APIHolder & api)
{
if (size == 0)
return;
size_t column_size = columns[offset]->size();
std::vector<PODArray<char>> data;
for (size_t i = 0; i < size; ++i)
{
const auto * column = columns[offset + i];
if (const auto * column_string = typeid_cast<const ColumnString *>(column))
calcStringHashes(column_string, i, buffer, api);
else if (const auto * column_fixed_string = typeid_cast<const ColumnFixedString *>(column))
calcStringHashes(column_fixed_string, i, buffer, api);
else
calcIntHashes(column_size, i, buffer, api);
}
}
}
/// Convert values to row-oriented format and call the evaluation function from the CatBoost wrapper API.
/// * CalcModelPredictionFlat if there are no cat features
/// * CalcModelPrediction if all cat features are strings
/// * CalcModelPredictionWithHashedCatFeatures if there are int cat features.
ColumnFloat64::MutablePtr CatBoostLibraryHandler::evalImpl(
const ColumnRawPtrs & columns,
bool cat_features_are_strings) const
{
std::string error_msg = "Error occurred while applying CatBoost model: ";
size_t column_size = columns.front()->size();
auto result = ColumnFloat64::create(column_size * tree_count);
auto * result_buf = result->getData().data();
if (!column_size)
return result;
/// Prepare float features.
PODArray<const float *> float_features(column_size);
auto * float_features_buf = float_features.data();
/// Store all float data into single column. float_features is a list of pointers to it.
auto float_features_col = placeNumericColumns<float>(columns, 0, float_features_count, float_features_buf);
if (cat_features_count == 0)
{
if (!api.CalcModelPredictionFlat(model_calcer_handle, column_size,
float_features_buf, float_features_count,
result_buf, column_size * tree_count))
{
throw Exception(error_msg + api.GetErrorString(), ErrorCodes::CANNOT_APPLY_CATBOOST_MODEL);
}
return result;
}
/// Prepare cat features.
if (cat_features_are_strings)
{
/// cat_features_holder stores pointers to ColumnString data or fixed_strings_data.
PODArray<const char *> cat_features_holder(cat_features_count * column_size);
PODArray<const char **> cat_features(column_size);
auto * cat_features_buf = cat_features.data();
fillCatFeaturesBuffer(cat_features_buf, cat_features_holder.data(), column_size, cat_features_count);
/// Fixed strings are stored without termination zero, so have to copy data into fixed_strings_data.
auto fixed_strings_data = placeStringColumns(columns, float_features_count,
cat_features_count, cat_features_holder.data());
if (!api.CalcModelPrediction(model_calcer_handle, column_size,
float_features_buf, float_features_count,
cat_features_buf, cat_features_count,
result_buf, column_size * tree_count))
{
throw Exception(error_msg + api.GetErrorString(), ErrorCodes::CANNOT_APPLY_CATBOOST_MODEL);
}
}
else
{
PODArray<const int *> cat_features(column_size);
auto * cat_features_buf = cat_features.data();
auto cat_features_col = placeNumericColumns<int>(columns, float_features_count,
cat_features_count, cat_features_buf);
calcHashes(columns, float_features_count, cat_features_count, cat_features_buf, api);
if (!api.CalcModelPredictionWithHashedCatFeatures(
model_calcer_handle, column_size,
float_features_buf, float_features_count,
cat_features_buf, cat_features_count,
result_buf, column_size * tree_count))
{
throw Exception(error_msg + api.GetErrorString(), ErrorCodes::CANNOT_APPLY_CATBOOST_MODEL);
}
}
return result;
}
size_t CatBoostLibraryHandler::getTreeCount() const
{
std::lock_guard lock(mutex);
return tree_count;
}
ColumnPtr CatBoostLibraryHandler::evaluate(const ColumnRawPtrs & columns) const
{
std::lock_guard lock(mutex);
if (columns.empty())
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Got empty columns list for CatBoost model.");
if (columns.size() != float_features_count + cat_features_count)
throw Exception(ErrorCodes::BAD_ARGUMENTS,
"Number of columns is different with number of features: columns size {} float features size {} + cat features size {}",
columns.size(),
float_features_count,
cat_features_count);
for (size_t i = 0; i < float_features_count; ++i)
{
if (!columns[i]->isNumeric())
{
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Column {} should be numeric to make float feature.", i);
}
}
bool cat_features_are_strings = true;
for (size_t i = float_features_count; i < float_features_count + cat_features_count; ++i)
{
const auto * column = columns[i];
if (column->isNumeric())
{
cat_features_are_strings = false;
}
else if (!(typeid_cast<const ColumnString *>(column)
|| typeid_cast<const ColumnFixedString *>(column)))
{
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Column {} should be numeric or string.", i);
}
}
auto result = evalImpl(columns, cat_features_are_strings);
if (tree_count == 1)
return result;
size_t column_size = columns.front()->size();
auto * result_buf = result->getData().data();
/// Multiple trees case. Copy data to several columns.
MutableColumns mutable_columns(tree_count);
std::vector<Float64 *> column_ptrs(tree_count);
for (size_t i = 0; i < tree_count; ++i)
{
auto col = ColumnFloat64::create(column_size);
column_ptrs[i] = col->getData().data();
mutable_columns[i] = std::move(col);
}
Float64 * data = result_buf;
for (size_t row = 0; row < column_size; ++row)
{
for (size_t i = 0; i < tree_count; ++i)
{
*column_ptrs[i] = *data;
++column_ptrs[i];
++data;
}
}
return ColumnTuple::create(std::move(mutable_columns));
}
}
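Review note: the multi-output branch of `evaluate()` above reads the flat result buffer row by row, with `tree_count` consecutive values per row, and scatters them into one column per output. The following is a minimal standalone sketch of that de-interleaving step using plain `std::vector` instead of ClickHouse columns. The helper name `splitInterleaved` and its signature are hypothetical and are not part of this PR.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

/// Hypothetical helper (not part of ClickHouse): splits a row-major buffer that
/// holds `dims` consecutive values per row into `dims` separate columns.
std::vector<std::vector<double>> splitInterleaved(const std::vector<double> & flat, size_t dims)
{
    assert(dims > 0 && flat.size() % dims == 0);
    const size_t rows = flat.size() / dims;

    std::vector<std::vector<double>> columns(dims, std::vector<double>(rows));
    for (size_t row = 0; row < rows; ++row)
        for (size_t i = 0; i < dims; ++i)
            columns[i][row] = flat[row * dims + i]; /// value i of this row goes to column i

    return columns;
}
```

Indexing with `row * dims + i` performs the same traversal as the pointer-bumping loop in the handler; the handler's version avoids the multiplication by advancing `data` and each `column_ptrs[i]` in lockstep.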

View File

@ -0,0 +1,78 @@
#pragma once
#include "CatBoostLibraryAPI.h"
#include <Columns/ColumnFixedString.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnVector.h>
#include <Columns/ColumnsNumber.h>
#include <Columns/IColumn.h>
#include <Common/SharedLibrary.h>
#include <base/defines.h>
#include <chrono>
#include <mutex>
namespace DB
{
/// Abstracts access to the CatBoost shared library.
class CatBoostLibraryHandler
{
public:
/// Holds pointers to CatBoost library functions
struct APIHolder
{
explicit APIHolder(SharedLibrary & lib);
// NOLINTBEGIN(readability-identifier-naming)
CatBoostLibraryAPI::ModelCalcerCreateFunc ModelCalcerCreate;
CatBoostLibraryAPI::ModelCalcerDeleteFunc ModelCalcerDelete;
CatBoostLibraryAPI::GetErrorStringFunc GetErrorString;
CatBoostLibraryAPI::LoadFullModelFromFileFunc LoadFullModelFromFile;
CatBoostLibraryAPI::CalcModelPredictionFlatFunc CalcModelPredictionFlat;
CatBoostLibraryAPI::CalcModelPredictionFunc CalcModelPrediction;
CatBoostLibraryAPI::CalcModelPredictionWithHashedCatFeaturesFunc CalcModelPredictionWithHashedCatFeatures;
CatBoostLibraryAPI::GetStringCatFeatureHashFunc GetStringCatFeatureHash;
CatBoostLibraryAPI::GetIntegerCatFeatureHashFunc GetIntegerCatFeatureHash;
CatBoostLibraryAPI::GetFloatFeaturesCountFunc GetFloatFeaturesCount;
CatBoostLibraryAPI::GetCatFeaturesCountFunc GetCatFeaturesCount;
CatBoostLibraryAPI::GetTreeCountFunc GetTreeCount;
CatBoostLibraryAPI::GetDimensionsCountFunc GetDimensionsCount;
// NOLINTEND(readability-identifier-naming)
};
CatBoostLibraryHandler(
const String & library_path,
const String & model_path);
~CatBoostLibraryHandler();
std::chrono::system_clock::time_point getLoadingStartTime() const;
std::chrono::milliseconds getLoadingDuration() const;
size_t getTreeCount() const;
ColumnPtr evaluate(const ColumnRawPtrs & columns) const;
private:
std::chrono::system_clock::time_point loading_start_time;
std::chrono::milliseconds loading_duration;
const SharedLibraryPtr library;
const APIHolder api;
mutable std::mutex mutex;
CatBoostLibraryAPI::ModelCalcerHandle * model_calcer_handle TSA_GUARDED_BY(mutex) TSA_PT_GUARDED_BY(mutex);
size_t float_features_count TSA_GUARDED_BY(mutex);
size_t cat_features_count TSA_GUARDED_BY(mutex);
size_t tree_count TSA_GUARDED_BY(mutex);
ColumnFloat64::MutablePtr evalImpl(const ColumnRawPtrs & columns, bool cat_features_are_strings) const TSA_REQUIRES(mutex);
};
using CatBoostLibraryHandlerPtr = std::shared_ptr<CatBoostLibraryHandler>;
}

View File

@ -0,0 +1,80 @@
#include "CatBoostLibraryHandlerFactory.h"
#include <Common/logger_useful.h>
namespace DB
{
CatBoostLibraryHandlerFactory & CatBoostLibraryHandlerFactory::instance()
{
static CatBoostLibraryHandlerFactory instance;
return instance;
}
CatBoostLibraryHandlerFactory::CatBoostLibraryHandlerFactory()
: log(&Poco::Logger::get("CatBoostLibraryHandlerFactory"))
{
}
CatBoostLibraryHandlerPtr CatBoostLibraryHandlerFactory::tryGetModel(const String & model_path, const String & library_path, bool create_if_not_found)
{
std::lock_guard lock(mutex);
auto handler = library_handlers.find(model_path);
bool found = (handler != library_handlers.end());
if (found)
return handler->second;
else
{
if (create_if_not_found)
{
auto new_handler = std::make_shared<CatBoostLibraryHandler>(library_path, model_path);
library_handlers.emplace(model_path, new_handler);
LOG_DEBUG(log, "Loaded catboost library handler for model path '{}'", model_path);
return new_handler;
}
return nullptr;
}
}
void CatBoostLibraryHandlerFactory::removeModel(const String & model_path)
{
std::lock_guard lock(mutex);
bool deleted = library_handlers.erase(model_path);
if (!deleted)
{
LOG_DEBUG(log, "Cannot unload catboost library handler for model path '{}'", model_path);
return;
}
LOG_DEBUG(log, "Unloaded catboost library handler for model path '{}'", model_path);
}
void CatBoostLibraryHandlerFactory::removeAllModels()
{
std::lock_guard lock(mutex);
library_handlers.clear();
LOG_DEBUG(log, "Unloaded all catboost library handlers");
}
ExternalModelInfos CatBoostLibraryHandlerFactory::getModelInfos()
{
std::lock_guard lock(mutex);
ExternalModelInfos result;
for (const auto & handler : library_handlers)
result.push_back({
.model_path = handler.first,
.model_type = "catboost",
.loading_start_time = handler.second->getLoadingStartTime(),
.loading_duration = handler.second->getLoadingDuration()
});
return result;
}
}
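Review note: the factory above is a mutex-guarded map from model path to a shared handler, with `tryGetModel()` either returning the cached entry, creating one, or returning `nullptr` depending on `create_if_not_found`. The sketch below reproduces only that lookup logic in isolation. `ToyHandler` and `ToyHandlerCache` are hypothetical stand-ins; the real handler loads the CatBoost shared library, which is omitted here.

```cpp
#include <cassert>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

/// Hypothetical stand-in for a loaded model handler.
struct ToyHandler
{
    std::string library_path;
    std::string model_path;
};

using ToyHandlerPtr = std::shared_ptr<ToyHandler>;

/// Hypothetical cache mirroring the get-or-create logic of the factory.
class ToyHandlerCache
{
public:
    ToyHandlerPtr tryGet(const std::string & model_path, const std::string & library_path, bool create_if_not_found)
    {
        std::lock_guard<std::mutex> lock(mutex);

        auto it = handlers.find(model_path);
        if (it != handlers.end())
            return it->second;      /// already loaded: library_path is ignored

        if (!create_if_not_found)
            return nullptr;         /// caller must handle "not loaded"

        auto handler = std::make_shared<ToyHandler>(ToyHandler{library_path, model_path});
        handlers.emplace(model_path, handler);
        return handler;
    }

    /// Returns true if an entry was actually removed.
    bool remove(const std::string & model_path)
    {
        std::lock_guard<std::mutex> lock(mutex);
        return handlers.erase(model_path) > 0;
    }

private:
    std::mutex mutex;
    std::unordered_map<std::string, ToyHandlerPtr> handlers;
};
```

Because entries are `shared_ptr`s, a caller that obtained a handler before `remove()` can keep using it safely; the underlying object is destroyed only when the last reference goes away.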

View File

@ -0,0 +1,37 @@
#pragma once
#include "CatBoostLibraryHandler.h"
#include <base/defines.h>
#include <Common/ExternalModelInfo.h>
#include <chrono>
#include <mutex>
#include <unordered_map>
namespace DB
{
class CatBoostLibraryHandlerFactory final : private boost::noncopyable
{
public:
static CatBoostLibraryHandlerFactory & instance();
CatBoostLibraryHandlerFactory();
CatBoostLibraryHandlerPtr tryGetModel(const String & model_path, const String & library_path, bool create_if_not_found);
void removeModel(const String & model_path);
void removeAllModels();
ExternalModelInfos getModelInfos();
private:
/// map: model path --> catboost library handler
std::unordered_map<String, CatBoostLibraryHandlerPtr> library_handlers TSA_GUARDED_BY(mutex);
std::mutex mutex;
Poco::Logger * log;
};
}

View File

@ -50,6 +50,6 @@ private:
void * lib_data;
};
using SharedLibraryHandlerPtr = std::shared_ptr<ExternalDictionaryLibraryHandler>;
using ExternalDictionaryLibraryHandlerPtr = std::shared_ptr<ExternalDictionaryLibraryHandler>;
}

View File

@ -1,37 +1,40 @@
#include "ExternalDictionaryLibraryHandlerFactory.h"
#include <Common/logger_useful.h>
namespace DB
{
SharedLibraryHandlerPtr ExternalDictionaryLibraryHandlerFactory::get(const std::string & dictionary_id)
ExternalDictionaryLibraryHandlerPtr ExternalDictionaryLibraryHandlerFactory::get(const String & dictionary_id)
{
std::lock_guard lock(mutex);
auto library_handler = library_handlers.find(dictionary_id);
if (library_handler != library_handlers.end())
return library_handler->second;
if (auto handler = library_handlers.find(dictionary_id); handler != library_handlers.end())
return handler->second;
return nullptr;
}
void ExternalDictionaryLibraryHandlerFactory::create(
const std::string & dictionary_id,
const std::string & library_path,
const std::vector<std::string> & library_settings,
const String & dictionary_id,
const String & library_path,
const std::vector<String> & library_settings,
const Block & sample_block,
const std::vector<std::string> & attributes_names)
const std::vector<String> & attributes_names)
{
std::lock_guard lock(mutex);
if (!library_handlers.contains(dictionary_id))
library_handlers.emplace(std::make_pair(dictionary_id, std::make_shared<ExternalDictionaryLibraryHandler>(library_path, library_settings, sample_block, attributes_names)));
else
if (library_handlers.contains(dictionary_id))
{
LOG_WARNING(&Poco::Logger::get("ExternalDictionaryLibraryHandlerFactory"), "Library handler with dictionary id {} already exists", dictionary_id);
return;
}
library_handlers.emplace(std::make_pair(dictionary_id, std::make_shared<ExternalDictionaryLibraryHandler>(library_path, library_settings, sample_block, attributes_names)));
}
bool ExternalDictionaryLibraryHandlerFactory::clone(const std::string & from_dictionary_id, const std::string & to_dictionary_id)
bool ExternalDictionaryLibraryHandlerFactory::clone(const String & from_dictionary_id, const String & to_dictionary_id)
{
std::lock_guard lock(mutex);
auto from_library_handler = library_handlers.find(from_dictionary_id);
@ -45,7 +48,7 @@ bool ExternalDictionaryLibraryHandlerFactory::clone(const std::string & from_dic
}
bool ExternalDictionaryLibraryHandlerFactory::remove(const std::string & dictionary_id)
bool ExternalDictionaryLibraryHandlerFactory::remove(const String & dictionary_id)
{
std::lock_guard lock(mutex);
/// extDict_libDelete is called in destructor.

View File

@ -17,22 +17,22 @@ class ExternalDictionaryLibraryHandlerFactory final : private boost::noncopyable
public:
static ExternalDictionaryLibraryHandlerFactory & instance();
SharedLibraryHandlerPtr get(const std::string & dictionary_id);
ExternalDictionaryLibraryHandlerPtr get(const String & dictionary_id);
void create(
const std::string & dictionary_id,
const std::string & library_path,
const std::vector<std::string> & library_settings,
const String & dictionary_id,
const String & library_path,
const std::vector<String> & library_settings,
const Block & sample_block,
const std::vector<std::string> & attributes_names);
const std::vector<String> & attributes_names);
bool clone(const std::string & from_dictionary_id, const std::string & to_dictionary_id);
bool clone(const String & from_dictionary_id, const String & to_dictionary_id);
bool remove(const std::string & dictionary_id);
bool remove(const String & dictionary_id);
private:
/// map: dict_id -> sharedLibraryHandler
std::unordered_map<std::string, SharedLibraryHandlerPtr> library_handlers TSA_GUARDED_BY(mutex);
std::unordered_map<String, ExternalDictionaryLibraryHandlerPtr> library_handlers TSA_GUARDED_BY(mutex);
std::mutex mutex;
};

View File

@ -27,12 +27,16 @@ std::unique_ptr<HTTPRequestHandler> LibraryBridgeHandlerFactory::createRequestHa
{
if (uri.getPath() == "/extdict_ping")
return std::make_unique<ExternalDictionaryLibraryBridgeExistsHandler>(keep_alive_timeout, getContext());
else if (uri.getPath() == "/catboost_ping")
return std::make_unique<CatBoostLibraryBridgeExistsHandler>(keep_alive_timeout, getContext());
}
if (request.getMethod() == Poco::Net::HTTPRequest::HTTP_POST)
{
if (uri.getPath() == "/extdict_request")
return std::make_unique<ExternalDictionaryLibraryBridgeRequestHandler>(keep_alive_timeout, getContext());
else if (uri.getPath() == "/catboost_request")
return std::make_unique<CatBoostLibraryBridgeRequestHandler>(keep_alive_timeout, getContext());
}
return nullptr;

View File

@ -1,24 +1,32 @@
#include "LibraryBridgeHandlers.h"
#include "CatBoostLibraryHandler.h"
#include "CatBoostLibraryHandlerFactory.h"
#include "ExternalDictionaryLibraryHandler.h"
#include "ExternalDictionaryLibraryHandlerFactory.h"
#include <Formats/FormatFactory.h>
#include <Server/HTTP/WriteBufferFromHTTPServerResponse.h>
#include <IO/WriteHelpers.h>
#include <IO/ReadBufferFromString.h>
#include <IO/ReadHelpers.h>
#include <Common/BridgeProtocolVersion.h>
#include <IO/WriteHelpers.h>
#include <Poco/Net/HTMLForm.h>
#include <Poco/Net/HTTPServerRequest.h>
#include <Poco/Net/HTTPServerResponse.h>
#include <Poco/Net/HTMLForm.h>
#include <Poco/ThreadPool.h>
#include <Processors/Formats/IOutputFormat.h>
#include <Processors/Formats/IInputFormat.h>
#include <QueryPipeline/QueryPipeline.h>
#include <Processors/Executors/CompletedPipelineExecutor.h>
#include <Processors/Executors/PullingPipelineExecutor.h>
#include <Processors/Formats/IInputFormat.h>
#include <Processors/Formats/IOutputFormat.h>
#include <Processors/Sources/SourceFromSingleChunk.h>
#include <QueryPipeline/Pipe.h>
#include <QueryPipeline/QueryPipeline.h>
#include <Server/HTTP/HTMLForm.h>
#include <IO/ReadBufferFromString.h>
#include <Server/HTTP/WriteBufferFromHTTPServerResponse.h>
#include <Formats/NativeReader.h>
#include <Formats/NativeWriter.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeString.h>
namespace DB
@ -31,7 +39,7 @@ namespace ErrorCodes
namespace
{
void processError(HTTPServerResponse & response, const std::string & message)
void processError(HTTPServerResponse & response, const String & message)
{
response.setStatusAndReason(HTTPResponse::HTTP_INTERNAL_SERVER_ERROR);
@ -41,7 +49,7 @@ namespace
LOG_WARNING(&Poco::Logger::get("LibraryBridge"), fmt::runtime(message));
}
std::shared_ptr<Block> parseColumns(std::string && column_string)
std::shared_ptr<Block> parseColumns(String && column_string)
{
auto sample_block = std::make_shared<Block>();
auto names_and_types = NamesAndTypesList::parse(column_string);
@ -59,10 +67,10 @@ namespace
return ids;
}
std::vector<std::string> parseNamesFromBinary(const std::string & names_string)
std::vector<String> parseNamesFromBinary(const String & names_string)
{
ReadBufferFromString buf(names_string);
std::vector<std::string> names;
std::vector<String> names;
readVectorBinary(names, buf);
return names;
}
@ -79,13 +87,15 @@ static void writeData(Block data, OutputFormatPtr format)
executor.execute();
}
ExternalDictionaryLibraryBridgeRequestHandler::ExternalDictionaryLibraryBridgeRequestHandler(size_t keep_alive_timeout_, ContextPtr context_)
: WithContext(context_)
, log(&Poco::Logger::get("ExternalDictionaryLibraryBridgeRequestHandler"))
, keep_alive_timeout(keep_alive_timeout_)
, log(&Poco::Logger::get("ExternalDictionaryLibraryBridgeRequestHandler"))
{
}
void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response)
{
LOG_TRACE(log, "Request URI: {}", request.getURI());
@ -97,7 +107,7 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
version = 0; /// assumed version for too old servers which do not send a version
else
{
String version_str = params.get("version");
const String & version_str = params.get("version");
if (!tryParse(version, version_str))
{
processError(response, "Unable to parse 'version' string in request URL: '" + version_str + "' Check if the server and library-bridge have the same version.");
@ -124,8 +134,8 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
return;
}
std::string method = params.get("method");
std::string dictionary_id = params.get("dictionary_id");
const String & method = params.get("method");
const String & dictionary_id = params.get("dictionary_id");
LOG_TRACE(log, "Library method: '{}', dictionary id: {}", method, dictionary_id);
WriteBufferFromHTTPServerResponse out(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD, keep_alive_timeout);
@ -141,7 +151,7 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
return;
}
std::string from_dictionary_id = params.get("from_dictionary_id");
const String & from_dictionary_id = params.get("from_dictionary_id");
bool cloned = false;
cloned = ExternalDictionaryLibraryHandlerFactory::instance().clone(from_dictionary_id, dictionary_id);
@ -166,7 +176,7 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
return;
}
std::string library_path = params.get("library_path");
const String & library_path = params.get("library_path");
if (!params.has("library_settings"))
{
@ -174,10 +184,10 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
return;
}
const auto & settings_string = params.get("library_settings");
const String & settings_string = params.get("library_settings");
LOG_DEBUG(log, "Parsing library settings from binary string");
std::vector<std::string> library_settings = parseNamesFromBinary(settings_string);
std::vector<String> library_settings = parseNamesFromBinary(settings_string);
/// Needed for library dictionary
if (!params.has("attributes_names"))
@ -186,10 +196,10 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
return;
}
const auto & attributes_string = params.get("attributes_names");
const String & attributes_string = params.get("attributes_names");
LOG_DEBUG(log, "Parsing attributes names from binary string");
std::vector<std::string> attributes_names = parseNamesFromBinary(attributes_string);
std::vector<String> attributes_names = parseNamesFromBinary(attributes_string);
/// Needed to parse block from binary string format
if (!params.has("sample_block"))
@ -197,7 +207,7 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
processError(response, "No 'sample_block' in request URL");
return;
}
std::string sample_block_string = params.get("sample_block");
String sample_block_string = params.get("sample_block");
std::shared_ptr<Block> sample_block;
try
@ -297,7 +307,7 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
return;
}
std::string requested_block_string = params.get("requested_block_sample");
String requested_block_string = params.get("requested_block_sample");
std::shared_ptr<Block> requested_sample_block;
try
@ -332,7 +342,8 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
}
else
{
LOG_WARNING(log, "Unknown library method: '{}'", method);
processError(response, "Unknown library method '" + method + "'");
LOG_ERROR(log, "Unknown library method: '{}'", method);
}
}
catch (...)
@ -362,6 +373,7 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
}
}
ExternalDictionaryLibraryBridgeExistsHandler::ExternalDictionaryLibraryBridgeExistsHandler(size_t keep_alive_timeout_, ContextPtr context_)
: WithContext(context_)
, keep_alive_timeout(keep_alive_timeout_)
@ -369,6 +381,7 @@ ExternalDictionaryLibraryBridgeExistsHandler::ExternalDictionaryLibraryBridgeExi
{
}
void ExternalDictionaryLibraryBridgeExistsHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response)
{
try
@ -382,7 +395,7 @@ void ExternalDictionaryLibraryBridgeExistsHandler::handleRequest(HTTPServerReque
return;
}
std::string dictionary_id = params.get("dictionary_id");
const String & dictionary_id = params.get("dictionary_id");
auto library_handler = ExternalDictionaryLibraryHandlerFactory::instance().get(dictionary_id);
@ -399,4 +412,230 @@ void ExternalDictionaryLibraryBridgeExistsHandler::handleRequest(HTTPServerReque
}
CatBoostLibraryBridgeRequestHandler::CatBoostLibraryBridgeRequestHandler(
size_t keep_alive_timeout_, ContextPtr context_)
: WithContext(context_)
, keep_alive_timeout(keep_alive_timeout_)
, log(&Poco::Logger::get("CatBoostLibraryBridgeRequestHandler"))
{
}
void CatBoostLibraryBridgeRequestHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response)
{
LOG_TRACE(log, "Request URI: {}", request.getURI());
HTMLForm params(getContext()->getSettingsRef(), request);
size_t version;
if (!params.has("version"))
version = 0; /// assumed version for too old servers which do not send a version
else
{
const String & version_str = params.get("version");
if (!tryParse(version, version_str))
{
processError(response, "Unable to parse 'version' string in request URL: '" + version_str + "' Check if the server and library-bridge have the same version.");
return;
}
}
if (version != LIBRARY_BRIDGE_PROTOCOL_VERSION)
{
/// backwards compatibility is considered unnecessary for now, just let the user know that the server and the bridge must be upgraded together
processError(response, "Server and library-bridge have different versions: '" + std::to_string(version) + "' vs. '" + std::to_string(LIBRARY_BRIDGE_PROTOCOL_VERSION) + "'");
return;
}
if (!params.has("method"))
{
processError(response, "No 'method' in request URL");
return;
}
const String & method = params.get("method");
LOG_TRACE(log, "Library method: '{}'", method);
WriteBufferFromHTTPServerResponse out(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD, keep_alive_timeout);
try
{
if (method == "catboost_list")
{
ExternalModelInfos model_infos = CatBoostLibraryHandlerFactory::instance().getModelInfos();
writeIntBinary(static_cast<UInt64>(model_infos.size()), out);
for (const auto & info : model_infos)
{
writeStringBinary(info.model_path, out);
writeStringBinary(info.model_type, out);
UInt64 t = std::chrono::system_clock::to_time_t(info.loading_start_time);
writeIntBinary(t, out);
t = info.loading_duration.count();
writeIntBinary(t, out);
}
}
else if (method == "catboost_removeModel")
{
auto & read_buf = request.getStream();
params.read(read_buf);
if (!params.has("model_path"))
{
processError(response, "No 'model_path' in request URL");
return;
}
const String & model_path = params.get("model_path");
CatBoostLibraryHandlerFactory::instance().removeModel(model_path);
String res = "1";
writeStringBinary(res, out);
}
else if (method == "catboost_removeAllModels")
{
CatBoostLibraryHandlerFactory::instance().removeAllModels();
String res = "1";
writeStringBinary(res, out);
}
else if (method == "catboost_GetTreeCount")
{
auto & read_buf = request.getStream();
params.read(read_buf);
if (!params.has("library_path"))
{
processError(response, "No 'library_path' in request URL");
return;
}
const String & library_path = params.get("library_path");
if (!params.has("model_path"))
{
processError(response, "No 'model_path' in request URL");
return;
}
const String & model_path = params.get("model_path");
auto catboost_handler = CatBoostLibraryHandlerFactory::instance().tryGetModel(model_path, library_path, /*create_if_not_found*/ true);
size_t tree_count = catboost_handler->getTreeCount();
writeIntBinary(tree_count, out);
}
else if (method == "catboost_libEvaluate")
{
auto & read_buf = request.getStream();
params.read(read_buf);
if (!params.has("model_path"))
{
processError(response, "No 'model_path' in request URL");
return;
}
const String & model_path = params.get("model_path");
if (!params.has("data"))
{
processError(response, "No 'data' in request URL");
return;
}
const String & data = params.get("data");
ReadBufferFromString string_read_buf(data);
NativeReader deserializer(string_read_buf, /*server_revision*/ 0);
Block block_read = deserializer.read();
Columns col_ptrs = block_read.getColumns();
ColumnRawPtrs col_raw_ptrs;
for (const auto & p : col_ptrs)
col_raw_ptrs.push_back(&*p);
auto catboost_handler = CatBoostLibraryHandlerFactory::instance().tryGetModel(model_path, "DummyLibraryPath", /*create_if_not_found*/ false);
if (!catboost_handler)
{
processError(response, "CatBoost library is not loaded for model '" + model_path + "'. Please try again.");
return;
}
ColumnPtr res_col = catboost_handler->evaluate(col_raw_ptrs);
DataTypePtr res_col_type = std::make_shared<DataTypeFloat64>();
String res_col_name = "res_col";
ColumnsWithTypeAndName res_cols_with_type_and_name = {{res_col, res_col_type, res_col_name}};
Block block_write(res_cols_with_type_and_name);
NativeWriter serializer{out, /*client_revision*/ 0, block_write};
serializer.write(block_write);
}
else
{
processError(response, "Unknown library method '" + method + "'");
LOG_ERROR(log, "Unknown library method: '{}'", method);
}
}
catch (...)
{
auto message = getCurrentExceptionMessage(true);
LOG_ERROR(log, "Failed to process request. Error: {}", message);
response.setStatusAndReason(Poco::Net::HTTPResponse::HTTP_INTERNAL_SERVER_ERROR, message); // can't call processError() here because the response may already have started being sent
try
{
writeStringBinary(message, out);
out.finalize();
}
catch (...)
{
tryLogCurrentException(log);
}
}
try
{
out.finalize();
}
catch (...)
{
tryLogCurrentException(log);
}
}
CatBoostLibraryBridgeExistsHandler::CatBoostLibraryBridgeExistsHandler(size_t keep_alive_timeout_, ContextPtr context_)
: WithContext(context_)
, keep_alive_timeout(keep_alive_timeout_)
, log(&Poco::Logger::get("CatBoostLibraryBridgeExistsHandler"))
{
}
void CatBoostLibraryBridgeExistsHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response)
{
try
{
LOG_TRACE(log, "Request URI: {}", request.getURI());
HTMLForm params(getContext()->getSettingsRef(), request);
String res = "1";
setResponseDefaultHeaders(response, keep_alive_timeout);
LOG_TRACE(log, "Sending ping response: {}", res);
response.sendBuffer(res.data(), res.size());
}
catch (...)
{
tryLogCurrentException("PingHandler");
}
}
}
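Review note: the `catboost_list` branch above writes a count followed by, per model, two strings and two integers, and `CatBoostLibraryBridgeHelper::listModels()` reads them back in the same order. The sketch below shows such a count-prefixed round trip with a deliberately simplified encoding (fixed-width little-endian `uint64_t` for integers and for string lengths). This is an assumption made for illustration; it is not the byte layout produced by ClickHouse's `writeStringBinary`/`writeIntBinary`, and all names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

/// Hypothetical mirror of the fields sent per model.
struct ToyModelInfo
{
    std::string model_path;
    std::string model_type;
    uint64_t loading_start_time_s = 0; /// seconds since epoch
    uint64_t loading_duration_ms = 0;  /// milliseconds
};

void putU64(std::string & out, uint64_t v)
{
    for (int i = 0; i < 8; ++i)
        out.push_back(static_cast<char>((v >> (8 * i)) & 0xFF));
}

void putString(std::string & out, const std::string & s)
{
    putU64(out, s.size()); /// simplified length prefix (assumption)
    out += s;
}

uint64_t getU64(const std::string & in, size_t & pos)
{
    uint64_t v = 0;
    for (int i = 0; i < 8; ++i)
        v |= static_cast<uint64_t>(static_cast<unsigned char>(in.at(pos++))) << (8 * i);
    return v;
}

std::string getString(const std::string & in, size_t & pos)
{
    const uint64_t size = getU64(in, pos);
    std::string s = in.substr(pos, size);
    pos += size;
    return s;
}

std::string encodeList(const std::vector<ToyModelInfo> & infos)
{
    std::string out;
    putU64(out, infos.size());
    for (const auto & info : infos)
    {
        putString(out, info.model_path);
        putString(out, info.model_type);
        putU64(out, info.loading_start_time_s);
        putU64(out, info.loading_duration_ms);
    }
    return out;
}

std::vector<ToyModelInfo> decodeList(const std::string & in)
{
    size_t pos = 0;
    const uint64_t count = getU64(in, pos);
    std::vector<ToyModelInfo> infos;
    for (uint64_t i = 0; i < count; ++i)
    {
        ToyModelInfo info;
        info.model_path = getString(in, pos);
        info.model_type = getString(in, pos);
        info.loading_start_time_s = getU64(in, pos);
        info.loading_duration_ms = getU64(in, pos);
        infos.push_back(info);
    }
    return infos;
}
```

The writer and reader must agree on field order and widths; in the PR this agreement is implicit between the bridge handler and the helper, which is one reason the protocol version check at the top of `handleRequest()` matters.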

View File

@ -1,9 +1,8 @@
#pragma once
#include <Common/logger_useful.h>
#include <Interpreters/Context.h>
#include <Server/HTTP/HTTPRequestHandler.h>
#include <Common/logger_useful.h>
#include "ExternalDictionaryLibraryHandler.h"
namespace DB
@ -26,11 +25,12 @@ public:
private:
static constexpr inline auto FORMAT = "RowBinary";
const size_t keep_alive_timeout;
Poco::Logger * log;
size_t keep_alive_timeout;
};
// Handler for checking if the external dictionary library is loaded (used for handshake)
class ExternalDictionaryLibraryBridgeExistsHandler : public HTTPRequestHandler, WithContext
{
public:
@ -43,4 +43,47 @@ private:
Poco::Logger * log;
};
/// Handler for requests to the catboost library. The call protocol is as follows:
/// (1) Send a "catboost_GetTreeCount" request from the server to the bridge. It contains a library path (e.g. /home/user/libcatboost.so) and
/// a model path (e.g. /home/user/model.bin). This loads the catboost library handler associated with the model path, then executes
/// GetTreeCount() on the library handler and sends the result back to the server.
/// (2) Send "catboost_libEvaluate" from the server to the bridge. It contains a model path and the features to run the inference on. Step
/// (2) is called multiple times (once per chunk) by the server.
///
/// We would ideally like to have steps (1) and (2) in one atomic handler but can't because the evaluation on the server side is divided
/// into two dependent phases: FunctionCatBoostEvaluate::getReturnTypeImpl() and ::executeImpl(). So the model may in principle be unloaded
/// from the library-bridge between steps (1) and (2). Step (2) checks if that is the case and fails gracefully. This is okay because that
/// situation is considered exceptional and rare.
///
/// An update of a model is performed by unloading it. The next call to "catboost_GetTreeCount" brings it into memory again.
///
/// Further handlers are provided for unloading a specific model, for unloading all models, and for retrieving information about the
/// loaded models for display in a system view.
class CatBoostLibraryBridgeRequestHandler : public HTTPRequestHandler, WithContext
{
public:
CatBoostLibraryBridgeRequestHandler(size_t keep_alive_timeout_, ContextPtr context_);
void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) override;
private:
const size_t keep_alive_timeout;
Poco::Logger * log;
};
// Handler for pinging the library-bridge for catboost access (used for handshake)
class CatBoostLibraryBridgeExistsHandler : public HTTPRequestHandler, WithContext
{
public:
CatBoostLibraryBridgeExistsHandler(size_t keep_alive_timeout_, ContextPtr context_);
void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) override;
private:
const size_t keep_alive_timeout;
Poco::Logger * log;
};
}

View File

@ -51,7 +51,6 @@
#include <Interpreters/DNSCacheUpdater.h>
#include <Interpreters/DatabaseCatalog.h>
#include <Interpreters/ExternalDictionariesLoader.h>
#include <Interpreters/ExternalModelsLoader.h>
#include <Interpreters/ProcessList.h>
#include <Interpreters/loadMetadata.h>
#include <Interpreters/UserDefinedSQLObjectsLoader.h>
@ -1158,7 +1157,6 @@ int Server::main(const std::vector<std::string> & /*args*/)
global_context->setExternalAuthenticatorsConfig(*config);
global_context->loadOrReloadDictionaries(*config);
global_context->loadOrReloadModels(*config);
global_context->loadOrReloadUserDefinedExecutableFunctions(*config);
global_context->setRemoteHostFilter(*config);
@ -1739,17 +1737,6 @@ int Server::main(const std::vector<std::string> & /*args*/)
throw;
}
/// try to load models immediately, throw on error and die
try
{
global_context->loadOrReloadModels(config());
}
catch (...)
{
tryLogCurrentException(log, "Caught exception while loading dictionaries.");
throw;
}
/// try to load user defined executable functions, throw on error and die
try
{

View File

@ -0,0 +1,194 @@
#include "CatBoostLibraryBridgeHelper.h"
#include <Columns/ColumnsNumber.h>
#include <Common/escapeForFileName.h>
#include <Core/Block.h>
#include <DataTypes/DataTypesNumber.h>
#include <Formats/NativeReader.h>
#include <Formats/NativeWriter.h>
#include <IO/ReadBufferFromString.h>
#include <IO/WriteBufferFromString.h>
#include <Poco/Net/HTTPRequest.h>
#include <random>
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}
CatBoostLibraryBridgeHelper::CatBoostLibraryBridgeHelper(
ContextPtr context_,
std::optional<String> model_path_,
std::optional<String> library_path_)
: LibraryBridgeHelper(context_->getGlobalContext())
, model_path(model_path_)
, library_path(library_path_)
{
}
Poco::URI CatBoostLibraryBridgeHelper::getPingURI() const
{
auto uri = createBaseURI();
uri.setPath(PING_HANDLER);
return uri;
}
Poco::URI CatBoostLibraryBridgeHelper::getMainURI() const
{
auto uri = createBaseURI();
uri.setPath(MAIN_HANDLER);
return uri;
}
Poco::URI CatBoostLibraryBridgeHelper::createRequestURI(const String & method) const
{
auto uri = getMainURI();
uri.addQueryParameter("version", std::to_string(LIBRARY_BRIDGE_PROTOCOL_VERSION));
uri.addQueryParameter("method", method);
return uri;
}
bool CatBoostLibraryBridgeHelper::bridgeHandShake()
{
String result;
try
{
ReadWriteBufferFromHTTP buf(getPingURI(), Poco::Net::HTTPRequest::HTTP_GET, {}, http_timeouts, credentials);
readString(result, buf);
}
catch (...)
{
tryLogCurrentException(log);
return false;
}
if (result != "1")
throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected message from library bridge: {}. Check that bridge and server have the same version.", result);
return true;
}
ExternalModelInfos CatBoostLibraryBridgeHelper::listModels()
{
startBridgeSync();
ReadWriteBufferFromHTTP buf(
createRequestURI(CATBOOST_LIST_METHOD),
Poco::Net::HTTPRequest::HTTP_POST,
[](std::ostream &) {},
http_timeouts, credentials);
ExternalModelInfos result;
UInt64 num_rows;
readIntBinary(num_rows, buf);
for (UInt64 i = 0; i < num_rows; ++i)
{
ExternalModelInfo info;
readStringBinary(info.model_path, buf);
readStringBinary(info.model_type, buf);
UInt64 t;
readIntBinary(t, buf);
info.loading_start_time = std::chrono::system_clock::from_time_t(t);
readIntBinary(t, buf);
info.loading_duration = std::chrono::milliseconds(t);
result.push_back(info);
}
return result;
}
void CatBoostLibraryBridgeHelper::removeModel()
{
startBridgeSync();
assert(model_path);
ReadWriteBufferFromHTTP buf(
createRequestURI(CATBOOST_REMOVEMODEL_METHOD),
Poco::Net::HTTPRequest::HTTP_POST,
[this](std::ostream & os)
{
os << "model_path=" << escapeForFileName(*model_path);
},
http_timeouts, credentials);
String result;
readStringBinary(result, buf);
assert(result == "1");
}
void CatBoostLibraryBridgeHelper::removeAllModels()
{
startBridgeSync();
ReadWriteBufferFromHTTP buf(
createRequestURI(CATBOOST_REMOVEALLMODELS_METHOD),
Poco::Net::HTTPRequest::HTTP_POST,
[](std::ostream &){},
http_timeouts, credentials);
String result;
readStringBinary(result, buf);
assert(result == "1");
}
size_t CatBoostLibraryBridgeHelper::getTreeCount()
{
startBridgeSync();
assert(model_path && library_path);
ReadWriteBufferFromHTTP buf(
createRequestURI(CATBOOST_GETTREECOUNT_METHOD),
Poco::Net::HTTPRequest::HTTP_POST,
[this](std::ostream & os)
{
os << "library_path=" << escapeForFileName(*library_path) << "&";
os << "model_path=" << escapeForFileName(*model_path);
},
http_timeouts, credentials);
size_t result;
readIntBinary(result, buf);
return result;
}
ColumnPtr CatBoostLibraryBridgeHelper::evaluate(const ColumnsWithTypeAndName & columns)
{
startBridgeSync();
WriteBufferFromOwnString string_write_buf;
Block block(columns);
NativeWriter serializer(string_write_buf, /*client_revision*/ 0, block);
serializer.write(block);
assert(model_path);
ReadWriteBufferFromHTTP buf(
createRequestURI(CATBOOST_LIB_EVALUATE_METHOD),
Poco::Net::HTTPRequest::HTTP_POST,
[this, serialized = string_write_buf.str()](std::ostream & os)
{
os << "model_path=" << escapeForFileName(*model_path) << "&";
os << "data=" << escapeForFileName(serialized);
},
http_timeouts, credentials);
NativeReader deserializer(buf, /*server_revision*/ 0);
Block block_read = deserializer.read();
return block_read.getColumns()[0];
}
}
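
The `listModels()` reader above expects a row count followed by four fields per model. A minimal sketch of that layout as a round trip over a `std::stringstream` is given below. The helpers `writeU64`, `readU64`, `writeStr` and `readStr` are hypothetical stand-ins, not ClickHouse's IO functions, and the 64-bit string length prefix is an assumption; the real `readStringBinary` may encode lengths differently.

```cpp
#include <cassert>
#include <chrono>
#include <cstdint>
#include <ctime>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

// Mirrors the fields of ExternalModelInfo from the diff.
struct ModelInfo
{
    std::string model_path;
    std::string model_type;
    std::chrono::system_clock::time_point loading_start_time;
    std::chrono::milliseconds loading_duration;
};

// Hypothetical stand-ins for binary IO helpers: fixed-width host-endian integers.
void writeU64(std::ostream & out, std::uint64_t v)
{
    out.write(reinterpret_cast<const char *>(&v), sizeof(v));
}

std::uint64_t readU64(std::istream & in)
{
    std::uint64_t v = 0;
    in.read(reinterpret_cast<char *>(&v), sizeof(v));
    return v;
}

// Assumption: strings carry a 64-bit length prefix followed by raw bytes.
void writeStr(std::ostream & out, const std::string & s)
{
    writeU64(out, s.size());
    out.write(s.data(), static_cast<std::streamsize>(s.size()));
}

std::string readStr(std::istream & in)
{
    std::string s(readU64(in), '\0');
    in.read(s.data(), static_cast<std::streamsize>(s.size()));
    return s;
}

// Writer side: row count, then path, type, start time (time_t) and duration (ms) per row.
void writeModelInfos(std::ostream & out, const std::vector<ModelInfo> & infos)
{
    writeU64(out, infos.size());
    for (const auto & info : infos)
    {
        writeStr(out, info.model_path);
        writeStr(out, info.model_type);
        writeU64(out, static_cast<std::uint64_t>(std::chrono::system_clock::to_time_t(info.loading_start_time)));
        writeU64(out, static_cast<std::uint64_t>(info.loading_duration.count()));
    }
}

// Reader side: same field order as the loop in listModels().
std::vector<ModelInfo> readModelInfos(std::istream & in)
{
    std::vector<ModelInfo> result;
    const std::uint64_t num_rows = readU64(in);
    for (std::uint64_t i = 0; i < num_rows; ++i)
    {
        ModelInfo info;
        info.model_path = readStr(in);
        info.model_type = readStr(in);
        info.loading_start_time = std::chrono::system_clock::from_time_t(static_cast<std::time_t>(readU64(in)));
        info.loading_duration = std::chrono::milliseconds(static_cast<std::chrono::milliseconds::rep>(readU64(in)));
        result.push_back(std::move(info));
    }
    return result;
}
```

Note that the start time travels as whole seconds (`std::time_t`), so any sub-second precision of `loading_start_time` is lost in transit.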

@ -0,0 +1,53 @@
#pragma once
#include <BridgeHelper/LibraryBridgeHelper.h>
#include <Common/ExternalModelInfo.h>
#include <DataTypes/IDataType.h>
#include <IO/ReadWriteBufferFromHTTP.h>
#include <Interpreters/Context.h>
#include <Poco/URI.h>
#include <optional>
namespace DB
{
class CatBoostLibraryBridgeHelper : public LibraryBridgeHelper
{
public:
static constexpr inline auto PING_HANDLER = "/catboost_ping";
static constexpr inline auto MAIN_HANDLER = "/catboost_request";
explicit CatBoostLibraryBridgeHelper(
ContextPtr context_,
std::optional<String> model_path_ = std::nullopt,
std::optional<String> library_path_ = std::nullopt);
ExternalModelInfos listModels();
void removeModel(); /// requires model_path
void removeAllModels();
size_t getTreeCount(); /// requires model_path and library_path
ColumnPtr evaluate(const ColumnsWithTypeAndName & columns); /// requires model_path
protected:
Poco::URI getPingURI() const override;
Poco::URI getMainURI() const override;
bool bridgeHandShake() override;
private:
static constexpr inline auto CATBOOST_LIST_METHOD = "catboost_list";
static constexpr inline auto CATBOOST_REMOVEMODEL_METHOD = "catboost_removeModel";
static constexpr inline auto CATBOOST_REMOVEALLMODELS_METHOD = "catboost_removeAllModels";
static constexpr inline auto CATBOOST_GETTREECOUNT_METHOD = "catboost_GetTreeCount";
static constexpr inline auto CATBOOST_LIB_EVALUATE_METHOD = "catboost_libEvaluate";
Poco::URI createRequestURI(const String & method) const;
const std::optional<String> model_path;
const std::optional<String> library_path;
};
}
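
Each helper method sends its parameters as `key=value` pairs joined by `&` in the POST body. A minimal sketch of that assembly follows. `escapeValue` and `buildBody` are hypothetical names, and the percent-style escaping (keep ASCII letters, digits and `_`, hex-encode the rest) is an assumption about what `escapeForFileName` does, not a copy of it.

```cpp
#include <cassert>
#include <cctype>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for escapeForFileName: keeps [A-Za-z0-9_], emits %XX for other bytes.
std::string escapeValue(const std::string & s)
{
    static const char * hex = "0123456789ABCDEF";
    std::string out;
    for (unsigned char c : s)
    {
        if (std::isalnum(c) || c == '_')
        {
            out += static_cast<char>(c);
        }
        else
        {
            out += '%';
            out += hex[c >> 4];
            out += hex[c & 0x0F];
        }
    }
    return out;
}

// Joins escaped parameters the way the request lambdas do: key=value&key=value.
std::string buildBody(const std::vector<std::pair<std::string, std::string>> & params)
{
    std::string body;
    for (const auto & [key, value] : params)
    {
        if (!body.empty())
            body += '&';
        body += key;
        body += '=';
        body += escapeValue(value);
    }
    return body;
}
```

Escaping every value matters here because the serialized block passed as `data` is arbitrary binary and could otherwise contain `&` or `=` bytes.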

@ -12,8 +12,8 @@
namespace DB
{
/// Common base class for XDBC and Library bridge helpers.
/// Contains helper methods to check/start bridge sync.
/// Base class for server-side bridge helpers, e.g. xdbc-bridge and library-bridge.
/// Contains helper methods to check/start bridge sync
class IBridgeHelper: protected WithContext
{

@ -176,10 +176,10 @@ static void tryLogCurrentExceptionImpl(Poco::Logger * logger, const std::string
void tryLogCurrentException(const char * log_name, const std::string & start_of_message)
{
/// Under high memory pressure, any new allocation will definitely lead
/// to MEMORY_LIMIT_EXCEEDED exception.
/// Under high memory pressure, new allocations throw a
/// MEMORY_LIMIT_EXCEEDED exception.
///
/// And in this case the exception will not be logged, so let's block the
/// In this case the exception will not be logged, so let's block the
/// MemoryTracker until the exception will be logged.
LockMemoryExceptionInThread lock_memory_tracker(VariableContext::Global);
@ -189,8 +189,8 @@ void tryLogCurrentException(const char * log_name, const std::string & start_of_
void tryLogCurrentException(Poco::Logger * logger, const std::string & start_of_message)
{
/// Under high memory pressure, any new allocation will definitely lead
/// to MEMORY_LIMIT_EXCEEDED exception.
/// Under high memory pressure, new allocations throw a
/// MEMORY_LIMIT_EXCEEDED exception.
///
/// And in this case the exception will not be logged, so let's block the
/// MemoryTracker until the exception will be logged.

@ -0,0 +1,20 @@
#pragma once
#include <chrono>
#include <vector>
#include <base/types.h>
namespace DB
{
/// Details about external machine learning model, used by clickhouse-server and clickhouse-library-bridge
struct ExternalModelInfo
{
String model_path;
String model_type;
std::chrono::system_clock::time_point loading_start_time; /// serialized as std::time_t
std::chrono::milliseconds loading_duration; /// serialized as UInt64
};
using ExternalModelInfos = std::vector<ExternalModelInfo>;
}

@ -1,18 +1,18 @@
#include <Functions/FunctionHelpers.h>
#include <Functions/FunctionFactory.h>
#include <base/range.h>
#include <Interpreters/Context.h>
#include <Interpreters/ExternalModelsLoader.h>
#include <Columns/ColumnString.h>
#include <string>
#include <memory>
#include <DataTypes/DataTypeNullable.h>
#include <BridgeHelper/CatBoostLibraryBridgeHelper.h>
#include <BridgeHelper/IBridgeHelper.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnTuple.h>
#include <DataTypes/DataTypeTuple.h>
#include <Columns/ColumnsNumber.h>
#include <Common/assert_cast.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/IFunction.h>
#include <Interpreters/Context.h>
#include <Interpreters/Context_fwd.h>
@ -21,66 +21,80 @@ namespace DB
namespace ErrorCodes
{
extern const int FILE_DOESNT_EXIST;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int TOO_FEW_ARGUMENTS_FOR_FUNCTION;
extern const int ILLEGAL_COLUMN;
}
class ExternalModelsLoader;
/// Evaluate external model.
/// First argument - model name, the others - model arguments.
/// * for CatBoost model - float features first, then categorical
/// Result - Float64.
class FunctionModelEvaluate final : public IFunction
/// Evaluate CatBoost model.
/// - Arguments: float features first, then categorical features.
/// - Result: Float64.
class FunctionCatBoostEvaluate final : public IFunction, WithContext
{
private:
mutable std::unique_ptr<CatBoostLibraryBridgeHelper> bridge_helper;
public:
static constexpr auto name = "modelEvaluate";
static constexpr auto name = "catboostEvaluate";
static FunctionPtr create(ContextPtr context)
{
return std::make_shared<FunctionModelEvaluate>(context->getExternalModelsLoader());
}
explicit FunctionModelEvaluate(const ExternalModelsLoader & models_loader_)
: models_loader(models_loader_) {}
static FunctionPtr create(ContextPtr context_) { return std::make_shared<FunctionCatBoostEvaluate>(context_); }
explicit FunctionCatBoostEvaluate(ContextPtr context_) : WithContext(context_) {}
String getName() const override { return name; }
bool isVariadic() const override { return true; }
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }
bool isDeterministic() const override { return false; }
bool useDefaultImplementationForNulls() const override { return false; }
size_t getNumberOfArguments() const override { return 0; }
void initBridge(const ColumnConst * name_col) const
{
String library_path = getContext()->getConfigRef().getString("catboost_lib_path");
if (!std::filesystem::exists(library_path))
throw Exception(ErrorCodes::FILE_DOESNT_EXIST, "Can't load library {}: file doesn't exist", library_path);
String model_path = name_col->getValue<String>();
if (!std::filesystem::exists(model_path))
throw Exception(ErrorCodes::FILE_DOESNT_EXIST, "Can't load model {}: file doesn't exist", model_path);
bridge_helper = std::make_unique<CatBoostLibraryBridgeHelper>(getContext(), model_path, library_path);
}
DataTypePtr getReturnTypeFromLibraryBridge() const
{
size_t tree_count = bridge_helper->getTreeCount();
auto type = std::make_shared<DataTypeFloat64>();
if (tree_count == 1)
return type;
DataTypes types(tree_count, type);
return std::make_shared<DataTypeTuple>(types);
}
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
if (arguments.size() < 2)
throw Exception("Function " + getName() + " expects at least 2 arguments",
ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION);
throw Exception(ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION, "Function {} expects at least 2 arguments", getName());
if (!isString(arguments[0].type))
throw Exception("Illegal type " + arguments[0].type->getName() + " of first argument of function " + getName()
+ ", expected a string.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal type {} of first argument of function {}, expected a string.", arguments[0].type->getName(), getName());
const auto * name_col = checkAndGetColumnConst<ColumnString>(arguments[0].column.get());
if (!name_col)
throw Exception("First argument of function " + getName() + " must be a constant string",
ErrorCodes::ILLEGAL_COLUMN);
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "First argument of function {} must be a constant string", getName());
initBridge(name_col);
auto type = getReturnTypeFromLibraryBridge();
bool has_nullable = false;
for (size_t i = 1; i < arguments.size(); ++i)
has_nullable = has_nullable || arguments[i].type->isNullable();
auto model = models_loader.getModel(name_col->getValue<String>());
auto type = model->getReturnType();
if (has_nullable)
{
if (const auto * tuple = typeid_cast<const DataTypeTuple *>(type.get()))
@ -98,31 +112,25 @@ public:
return type;
}
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t) const override
{
const auto * name_col = checkAndGetColumnConst<ColumnString>(arguments[0].column.get());
if (!name_col)
throw Exception("First argument of function " + getName() + " must be a constant string",
ErrorCodes::ILLEGAL_COLUMN);
auto model = models_loader.getModel(name_col->getValue<String>());
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "First argument of function {} must be a constant string", getName());
ColumnRawPtrs column_ptrs;
Columns materialized_columns;
ColumnPtr null_map;
column_ptrs.reserve(arguments.size());
for (auto arg : collections::range(1, arguments.size()))
ColumnsWithTypeAndName feature_arguments(arguments.begin() + 1, arguments.end());
for (auto & arg : feature_arguments)
{
const auto & column = arguments[arg].column;
column_ptrs.push_back(column.get());
if (auto full_column = column->convertToFullColumnIfConst())
if (auto full_column = arg.column->convertToFullColumnIfConst())
{
materialized_columns.push_back(full_column);
column_ptrs.back() = full_column.get();
arg.column = full_column;
}
if (const auto * col_nullable = checkAndGetColumn<ColumnNullable>(*column_ptrs.back()))
if (const auto * col_nullable = checkAndGetColumn<ColumnNullable>(&*arg.column))
{
if (!null_map)
null_map = col_nullable->getNullMapColumnPtr();
@ -140,11 +148,12 @@ public:
null_map = std::move(mut_null_map);
}
column_ptrs.back() = &col_nullable->getNestedColumn();
arg.column = col_nullable->getNestedColumn().getPtr();
arg.type = static_cast<const DataTypeNullable &>(*arg.type).getNestedType();
}
}
auto res = model->evaluate(column_ptrs);
auto res = bridge_helper->evaluate(feature_arguments);
if (null_map)
{
@ -162,15 +171,12 @@ public:
return res;
}
private:
const ExternalModelsLoader & models_loader;
};
REGISTER_FUNCTION(ExternalModels)
REGISTER_FUNCTION(CatBoostEvaluate)
{
factory.registerFunction<FunctionModelEvaluate>();
factory.registerFunction<FunctionCatBoostEvaluate>();
}
}

@ -1,525 +0,0 @@
#include "CatBoostModel.h"
#include <Common/FieldVisitorConvertToNumber.h>
#include <mutex>
#include <Columns/ColumnString.h>
#include <Columns/ColumnFixedString.h>
#include <Columns/ColumnVector.h>
#include <Columns/ColumnTuple.h>
#include <Common/typeid_cast.h>
#include <IO/WriteBufferFromString.h>
#include <IO/Operators.h>
#include <Common/PODArray.h>
#include <Common/SharedLibrary.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeTuple.h>
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
extern const int BAD_ARGUMENTS;
extern const int CANNOT_LOAD_CATBOOST_MODEL;
extern const int CANNOT_APPLY_CATBOOST_MODEL;
}
/// CatBoost wrapper interface functions.
class CatBoostWrapperAPI
{
public:
using ModelCalcerHandle = void;
ModelCalcerHandle * (* ModelCalcerCreate)(); // NOLINT
void (* ModelCalcerDelete)(ModelCalcerHandle * calcer); // NOLINT
const char * (* GetErrorString)(); // NOLINT
bool (* LoadFullModelFromFile)(ModelCalcerHandle * calcer, const char * filename); // NOLINT
bool (* CalcModelPredictionFlat)(ModelCalcerHandle * calcer, size_t docCount, // NOLINT
const float ** floatFeatures, size_t floatFeaturesSize,
double * result, size_t resultSize);
bool (* CalcModelPrediction)(ModelCalcerHandle * calcer, size_t docCount, // NOLINT
const float ** floatFeatures, size_t floatFeaturesSize,
const char *** catFeatures, size_t catFeaturesSize,
double * result, size_t resultSize);
bool (* CalcModelPredictionWithHashedCatFeatures)(ModelCalcerHandle * calcer, size_t docCount, // NOLINT
const float ** floatFeatures, size_t floatFeaturesSize,
const int ** catFeatures, size_t catFeaturesSize,
double * result, size_t resultSize);
int (* GetStringCatFeatureHash)(const char * data, size_t size); // NOLINT
int (* GetIntegerCatFeatureHash)(uint64_t val); // NOLINT
size_t (* GetFloatFeaturesCount)(ModelCalcerHandle* calcer); // NOLINT
size_t (* GetCatFeaturesCount)(ModelCalcerHandle* calcer); // NOLINT
size_t (* GetTreeCount)(ModelCalcerHandle* modelHandle); // NOLINT
size_t (* GetDimensionsCount)(ModelCalcerHandle* modelHandle); // NOLINT
bool (* CheckModelMetadataHasKey)(ModelCalcerHandle* modelHandle, const char* keyPtr, size_t keySize); // NOLINT
size_t (*GetModelInfoValueSize)(ModelCalcerHandle* modelHandle, const char* keyPtr, size_t keySize); // NOLINT
const char* (*GetModelInfoValue)(ModelCalcerHandle* modelHandle, const char* keyPtr, size_t keySize); // NOLINT
};
class CatBoostModelHolder
{
private:
CatBoostWrapperAPI::ModelCalcerHandle * handle;
const CatBoostWrapperAPI * api;
public:
explicit CatBoostModelHolder(const CatBoostWrapperAPI * api_) : api(api_) { handle = api->ModelCalcerCreate(); }
~CatBoostModelHolder() { api->ModelCalcerDelete(handle); }
CatBoostWrapperAPI::ModelCalcerHandle * get() { return handle; }
};
/// Holds CatBoost wrapper library and provides wrapper interface.
class CatBoostLibHolder
{
public:
explicit CatBoostLibHolder(std::string lib_path_) : lib_path(std::move(lib_path_)), lib(lib_path) { initAPI(); }
const CatBoostWrapperAPI & getAPI() const { return api; }
const std::string & getCurrentPath() const { return lib_path; }
private:
CatBoostWrapperAPI api;
std::string lib_path;
SharedLibrary lib;
void initAPI()
{
load(api.ModelCalcerCreate, "ModelCalcerCreate");
load(api.ModelCalcerDelete, "ModelCalcerDelete");
load(api.GetErrorString, "GetErrorString");
load(api.LoadFullModelFromFile, "LoadFullModelFromFile");
load(api.CalcModelPredictionFlat, "CalcModelPredictionFlat");
load(api.CalcModelPrediction, "CalcModelPrediction");
load(api.CalcModelPredictionWithHashedCatFeatures, "CalcModelPredictionWithHashedCatFeatures");
load(api.GetStringCatFeatureHash, "GetStringCatFeatureHash");
load(api.GetIntegerCatFeatureHash, "GetIntegerCatFeatureHash");
load(api.GetFloatFeaturesCount, "GetFloatFeaturesCount");
load(api.GetCatFeaturesCount, "GetCatFeaturesCount");
tryLoad(api.CheckModelMetadataHasKey, "CheckModelMetadataHasKey");
tryLoad(api.GetModelInfoValueSize, "GetModelInfoValueSize");
tryLoad(api.GetModelInfoValue, "GetModelInfoValue");
tryLoad(api.GetTreeCount, "GetTreeCount");
tryLoad(api.GetDimensionsCount, "GetDimensionsCount");
}
template <typename T>
void load(T& func, const std::string & name) { func = lib.get<T>(name); }
template <typename T>
void tryLoad(T& func, const std::string & name) { func = lib.tryGet<T>(name); }
};
std::shared_ptr<CatBoostLibHolder> getCatBoostWrapperHolder(const std::string & lib_path)
{
static std::shared_ptr<CatBoostLibHolder> ptr;
static std::mutex mutex;
std::lock_guard lock(mutex);
if (!ptr || ptr->getCurrentPath() != lib_path)
ptr = std::make_shared<CatBoostLibHolder>(lib_path);
return ptr;
}
class CatBoostModelImpl
{
public:
CatBoostModelImpl(const CatBoostWrapperAPI * api_, const std::string & model_path) : api(api_)
{
handle = std::make_unique<CatBoostModelHolder>(api);
if (!handle)
{
throw Exception(ErrorCodes::CANNOT_LOAD_CATBOOST_MODEL,
"Cannot create CatBoost model: {}",
api->GetErrorString());
}
if (!api->LoadFullModelFromFile(handle->get(), model_path.c_str()))
{
throw Exception(ErrorCodes::CANNOT_LOAD_CATBOOST_MODEL,
"Cannot load CatBoost model: {}",
api->GetErrorString());
}
float_features_count = api->GetFloatFeaturesCount(handle->get());
cat_features_count = api->GetCatFeaturesCount(handle->get());
tree_count = 1;
if (api->GetDimensionsCount)
tree_count = api->GetDimensionsCount(handle->get());
}
ColumnPtr evaluate(const ColumnRawPtrs & columns) const
{
if (columns.empty())
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Got empty columns list for CatBoost model.");
if (columns.size() != float_features_count + cat_features_count)
throw Exception(ErrorCodes::BAD_ARGUMENTS,
"Number of columns is different with number of features: columns size {} float features size {} + cat features size {}",
columns.size(),
float_features_count,
cat_features_count);
for (size_t i = 0; i < float_features_count; ++i)
{
if (!columns[i]->isNumeric())
{
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Column {} should be numeric to make float feature.", i);
}
}
bool cat_features_are_strings = true;
for (size_t i = float_features_count; i < float_features_count + cat_features_count; ++i)
{
const auto * column = columns[i];
if (column->isNumeric())
{
cat_features_are_strings = false;
}
else if (!(typeid_cast<const ColumnString *>(column)
|| typeid_cast<const ColumnFixedString *>(column)))
{
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Column {} should be numeric or string.", i);
}
}
auto result = evalImpl(columns, cat_features_are_strings);
if (tree_count == 1)
return result;
size_t column_size = columns.front()->size();
auto * result_buf = result->getData().data();
/// Multiple trees case. Copy data to several columns.
MutableColumns mutable_columns(tree_count);
std::vector<Float64 *> column_ptrs(tree_count);
for (size_t i = 0; i < tree_count; ++i)
{
auto col = ColumnFloat64::create(column_size);
column_ptrs[i] = col->getData().data();
mutable_columns[i] = std::move(col);
}
Float64 * data = result_buf;
for (size_t row = 0; row < column_size; ++row)
{
for (size_t i = 0; i < tree_count; ++i)
{
*column_ptrs[i] = *data;
++column_ptrs[i];
++data;
}
}
return ColumnTuple::create(std::move(mutable_columns));
}
size_t getFloatFeaturesCount() const { return float_features_count; }
size_t getCatFeaturesCount() const { return cat_features_count; }
size_t getTreeCount() const { return tree_count; }
private:
std::unique_ptr<CatBoostModelHolder> handle;
const CatBoostWrapperAPI * api;
size_t float_features_count;
size_t cat_features_count;
size_t tree_count;
/// Buffer should be allocated with features_count * column->size() elements.
/// Place column elements in positions buffer[0], buffer[features_count], ... , buffer[size * features_count]
template <typename T>
void placeColumnAsNumber(const IColumn * column, T * buffer, size_t features_count) const
{
size_t size = column->size();
FieldVisitorConvertToNumber<T> visitor;
for (size_t i = 0; i < size; ++i)
{
/// TODO: Replace with column visitor.
Field field;
column->get(i, field);
*buffer = applyVisitor(visitor, field);
buffer += features_count;
}
}
/// Buffer should be allocated with features_count * column->size() elements.
/// Place string pointers in positions buffer[0], buffer[features_count], ... , buffer[size * features_count]
static void placeStringColumn(const ColumnString & column, const char ** buffer, size_t features_count)
{
size_t size = column.size();
for (size_t i = 0; i < size; ++i)
{
*buffer = const_cast<char *>(column.getDataAtWithTerminatingZero(i).data);
buffer += features_count;
}
}
/// Buffer should be allocated with features_count * column->size() elements.
/// Place string pointers in positions buffer[0], buffer[features_count], ... , buffer[size * features_count]
/// Returns PODArray which holds data (because ColumnFixedString doesn't store terminating zero).
static PODArray<char> placeFixedStringColumn(
const ColumnFixedString & column, const char ** buffer, size_t features_count)
{
size_t size = column.size();
size_t str_size = column.getN();
PODArray<char> data(size * (str_size + 1));
char * data_ptr = data.data();
for (size_t i = 0; i < size; ++i)
{
auto ref = column.getDataAt(i);
memcpy(data_ptr, ref.data, ref.size);
data_ptr[ref.size] = 0;
*buffer = data_ptr;
data_ptr += ref.size + 1;
buffer += features_count;
}
return data;
}
/// Places columns into buffer and returns the column which holds the placed data. Buffer should contain column->size() values.
template <typename T>
ColumnPtr placeNumericColumns(const ColumnRawPtrs & columns,
size_t offset, size_t size, const T** buffer) const
{
if (size == 0)
return nullptr;
size_t column_size = columns[offset]->size();
auto data_column = ColumnVector<T>::create(size * column_size);
T * data = data_column->getData().data();
for (size_t i = 0; i < size; ++i)
{
const auto * column = columns[offset + i];
if (column->isNumeric())
placeColumnAsNumber(column, data + i, size);
}
for (size_t i = 0; i < column_size; ++i)
{
*buffer = data;
++buffer;
data += size;
}
return data_column;
}
/// Places columns into buffer and returns the data which was used for fixed string columns.
/// Buffer should contain column->size() values; each value contains size strings.
static std::vector<PODArray<char>> placeStringColumns(
const ColumnRawPtrs & columns, size_t offset, size_t size, const char ** buffer)
{
if (size == 0)
return {};
std::vector<PODArray<char>> data;
for (size_t i = 0; i < size; ++i)
{
const auto * column = columns[offset + i];
if (const auto * column_string = typeid_cast<const ColumnString *>(column))
placeStringColumn(*column_string, buffer + i, size);
else if (const auto * column_fixed_string = typeid_cast<const ColumnFixedString *>(column))
data.push_back(placeFixedStringColumn(*column_fixed_string, buffer + i, size));
else
throw Exception("Cannot place string column.", ErrorCodes::LOGICAL_ERROR);
}
return data;
}
/// Calc hash for string cat feature at ps positions.
template <typename Column>
void calcStringHashes(const Column * column, size_t ps, const int ** buffer) const
{
size_t column_size = column->size();
for (size_t j = 0; j < column_size; ++j)
{
auto ref = column->getDataAt(j);
const_cast<int *>(*buffer)[ps] = api->GetStringCatFeatureHash(ref.data, ref.size);
++buffer;
}
}
/// Calc hash for int cat feature at position ps. Buffer at position ps should contain unhashed values.
void calcIntHashes(size_t column_size, size_t ps, const int ** buffer) const
{
for (size_t j = 0; j < column_size; ++j)
{
const_cast<int *>(*buffer)[ps] = api->GetIntegerCatFeatureHash((*buffer)[ps]);
++buffer;
}
}
/// buffer contains column->size() rows and size columns.
/// For int cat features calc hash inplace.
/// For string cat features calc hash from column rows.
void calcHashes(const ColumnRawPtrs & columns, size_t offset, size_t size, const int ** buffer) const
{
if (size == 0)
return;
size_t column_size = columns[offset]->size();
std::vector<PODArray<char>> data;
for (size_t i = 0; i < size; ++i)
{
const auto * column = columns[offset + i];
if (const auto * column_string = typeid_cast<const ColumnString *>(column))
calcStringHashes(column_string, i, buffer);
else if (const auto * column_fixed_string = typeid_cast<const ColumnFixedString *>(column))
calcStringHashes(column_fixed_string, i, buffer);
else
calcIntHashes(column_size, i, buffer);
}
}
/// buffer[column_size * cat_features_count] -> char * => cat_features[column_size][cat_features_count] -> char *
void fillCatFeaturesBuffer(const char *** cat_features, const char ** buffer,
size_t column_size) const
{
for (size_t i = 0; i < column_size; ++i)
{
*cat_features = buffer;
++cat_features;
buffer += cat_features_count;
}
}
/// Convert values to row-oriented format and call evaluation function from CatBoost wrapper api.
/// * CalcModelPredictionFlat if no cat features
/// * CalcModelPrediction if all cat features are strings
/// * CalcModelPredictionWithHashedCatFeatures if has int cat features.
ColumnFloat64::MutablePtr evalImpl(
const ColumnRawPtrs & columns,
bool cat_features_are_strings) const
{
std::string error_msg = "Error occurred while applying CatBoost model: ";
size_t column_size = columns.front()->size();
auto result = ColumnFloat64::create(column_size * tree_count);
auto * result_buf = result->getData().data();
if (!column_size)
return result;
/// Prepare float features.
PODArray<const float *> float_features(column_size);
auto * float_features_buf = float_features.data();
/// Store all float data into single column. float_features is a list of pointers to it.
auto float_features_col = placeNumericColumns<float>(columns, 0, float_features_count, float_features_buf);
if (cat_features_count == 0)
{
if (!api->CalcModelPredictionFlat(handle->get(), column_size,
float_features_buf, float_features_count,
result_buf, column_size * tree_count))
{
throw Exception(error_msg + api->GetErrorString(), ErrorCodes::CANNOT_APPLY_CATBOOST_MODEL);
}
return result;
}
/// Prepare cat features.
if (cat_features_are_strings)
{
/// cat_features_holder stores pointers to ColumnString data or fixed_strings_data.
PODArray<const char *> cat_features_holder(cat_features_count * column_size);
PODArray<const char **> cat_features(column_size);
auto * cat_features_buf = cat_features.data();
fillCatFeaturesBuffer(cat_features_buf, cat_features_holder.data(), column_size);
/// Fixed strings are stored without termination zero, so have to copy data into fixed_strings_data.
auto fixed_strings_data = placeStringColumns(columns, float_features_count,
cat_features_count, cat_features_holder.data());
if (!api->CalcModelPrediction(handle->get(), column_size,
float_features_buf, float_features_count,
cat_features_buf, cat_features_count,
result_buf, column_size * tree_count))
{
throw Exception(error_msg + api->GetErrorString(), ErrorCodes::CANNOT_APPLY_CATBOOST_MODEL);
}
}
else
{
PODArray<const int *> cat_features(column_size);
auto * cat_features_buf = cat_features.data();
auto cat_features_col = placeNumericColumns<int>(columns, float_features_count,
cat_features_count, cat_features_buf);
calcHashes(columns, float_features_count, cat_features_count, cat_features_buf);
if (!api->CalcModelPredictionWithHashedCatFeatures(
handle->get(), column_size,
float_features_buf, float_features_count,
cat_features_buf, cat_features_count,
result_buf, column_size * tree_count))
{
throw Exception(error_msg + api->GetErrorString(), ErrorCodes::CANNOT_APPLY_CATBOOST_MODEL);
}
}
return result;
}
};
CatBoostModel::CatBoostModel(std::string name_, std::string model_path_, std::string lib_path_,
const ExternalLoadableLifetime & lifetime_)
: name(std::move(name_)), model_path(std::move(model_path_)), lib_path(std::move(lib_path_)), lifetime(lifetime_)
{
api_provider = getCatBoostWrapperHolder(lib_path);
api = &api_provider->getAPI();
model = std::make_unique<CatBoostModelImpl>(api, model_path);
}
CatBoostModel::~CatBoostModel() = default;
size_t CatBoostModel::getFloatFeaturesCount() const
{
return model->getFloatFeaturesCount();
}
size_t CatBoostModel::getCatFeaturesCount() const
{
return model->getCatFeaturesCount();
}
size_t CatBoostModel::getTreeCount() const
{
return model->getTreeCount();
}
DataTypePtr CatBoostModel::getReturnType() const
{
size_t tree_count = getTreeCount();
auto type = std::make_shared<DataTypeFloat64>();
if (tree_count == 1)
return type;
DataTypes types(tree_count, type);
return std::make_shared<DataTypeTuple>(types);
}
ColumnPtr CatBoostModel::evaluate(const ColumnRawPtrs & columns) const
{
if (!model)
throw Exception("CatBoost model was not loaded.", ErrorCodes::LOGICAL_ERROR);
return model->evaluate(columns);
}
}
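
For models with more than one output dimension, the removed `evaluate()` receives a row-major result buffer and copies it into one column per dimension before building the tuple. A minimal sketch of that de-interleaving step with plain `std::vector` follows. `splitByDimension` is a hypothetical name, and the vectors stand in for `ColumnFloat64`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Splits a row-major buffer [row0_dim0, row0_dim1, ..., row1_dim0, ...]
// into one vector per dimension, each holding one value per row.
std::vector<std::vector<double>> splitByDimension(const std::vector<double> & flat, std::size_t dimensions)
{
    assert(dimensions > 0 && flat.size() % dimensions == 0);
    const std::size_t rows = flat.size() / dimensions;

    std::vector<std::vector<double>> columns(dimensions, std::vector<double>(rows));
    for (std::size_t row = 0; row < rows; ++row)
        for (std::size_t dim = 0; dim < dimensions; ++dim)
            columns[dim][row] = flat[row * dimensions + dim];

    return columns;
}
```

With three rows and two dimensions, the buffer `{1, 2, 3, 4, 5, 6}` becomes the columns `{1, 3, 5}` and `{2, 4, 6}`.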

@ -1,73 +0,0 @@
#pragma once
#include <Interpreters/IExternalLoadable.h>
#include <Columns/IColumn.h>
#include <Columns/ColumnsNumber.h>
namespace DB
{
class CatBoostLibHolder;
class CatBoostWrapperAPI;
class CatBoostModelImpl;
class IDataType;
using DataTypePtr = std::shared_ptr<const IDataType>;
/// General ML model evaluator interface.
class IMLModel : public IExternalLoadable
{
public:
IMLModel() = default;
virtual ColumnPtr evaluate(const ColumnRawPtrs & columns) const = 0;
virtual std::string getTypeName() const = 0;
virtual DataTypePtr getReturnType() const = 0;
virtual ~IMLModel() override = default;
};
class CatBoostModel : public IMLModel
{
public:
CatBoostModel(std::string name, std::string model_path,
std::string lib_path, const ExternalLoadableLifetime & lifetime);
~CatBoostModel() override;
ColumnPtr evaluate(const ColumnRawPtrs & columns) const override;
std::string getTypeName() const override { return "catboost"; }
size_t getFloatFeaturesCount() const;
size_t getCatFeaturesCount() const;
size_t getTreeCount() const;
DataTypePtr getReturnType() const override;
/// IExternalLoadable interface.
const ExternalLoadableLifetime & getLifetime() const override { return lifetime; }
std::string getLoadableName() const override { return name; }
bool supportUpdates() const override { return true; }
bool isModified() const override { return true; }
std::shared_ptr<const IExternalLoadable> clone() const override
{
return std::make_shared<CatBoostModel>(name, model_path, lib_path, lifetime);
}
private:
const std::string name;
std::string model_path;
std::string lib_path;
ExternalLoadableLifetime lifetime;
std::shared_ptr<CatBoostLibHolder> api_provider;
const CatBoostWrapperAPI * api;
std::unique_ptr<CatBoostModelImpl> model;
void init();
};
}

@ -52,7 +52,6 @@
#include <Interpreters/EmbeddedDictionaries.h>
#include <Interpreters/ExternalDictionariesLoader.h>
#include <Interpreters/ExternalUserDefinedExecutableFunctionsLoader.h>
#include <Interpreters/ExternalModelsLoader.h>
#include <Interpreters/ExpressionActions.h>
#include <Interpreters/ProcessList.h>
#include <Interpreters/InterserverCredentials.h>
@ -153,7 +152,6 @@ struct ContextSharedPart : boost::noncopyable
mutable std::mutex embedded_dictionaries_mutex;
mutable std::mutex external_dictionaries_mutex;
mutable std::mutex external_user_defined_executable_functions_mutex;
mutable std::mutex external_models_mutex;
/// Separate mutex for storage policies. During server startup we may
/// initialize some important storages (system logs with MergeTree engine)
/// under context lock.
@ -191,9 +189,7 @@ struct ContextSharedPart : boost::noncopyable
mutable std::unique_ptr<EmbeddedDictionaries> embedded_dictionaries; /// Metrica's dictionaries. Have lazy initialization.
mutable std::unique_ptr<ExternalDictionariesLoader> external_dictionaries_loader;
mutable std::unique_ptr<ExternalUserDefinedExecutableFunctionsLoader> external_user_defined_executable_functions_loader;
mutable std::unique_ptr<ExternalModelsLoader> external_models_loader;
ExternalLoaderXMLConfigRepository * external_models_config_repository = nullptr;
scope_guard models_repository_guard;
ExternalLoaderXMLConfigRepository * external_dictionaries_config_repository = nullptr;
@ -359,8 +355,6 @@ struct ContextSharedPart : boost::noncopyable
external_dictionaries_loader->enablePeriodicUpdates(false);
if (external_user_defined_executable_functions_loader)
external_user_defined_executable_functions_loader->enablePeriodicUpdates(false);
if (external_models_loader)
external_models_loader->enablePeriodicUpdates(false);
Session::shutdownNamedSessions();
@ -391,7 +385,6 @@ struct ContextSharedPart : boost::noncopyable
std::unique_ptr<EmbeddedDictionaries> delete_embedded_dictionaries;
std::unique_ptr<ExternalDictionariesLoader> delete_external_dictionaries_loader;
std::unique_ptr<ExternalUserDefinedExecutableFunctionsLoader> delete_external_user_defined_executable_functions_loader;
std::unique_ptr<ExternalModelsLoader> delete_external_models_loader;
std::unique_ptr<BackgroundSchedulePool> delete_buffer_flush_schedule_pool;
std::unique_ptr<BackgroundSchedulePool> delete_schedule_pool;
std::unique_ptr<BackgroundSchedulePool> delete_distributed_schedule_pool;
@ -430,7 +423,6 @@ struct ContextSharedPart : boost::noncopyable
delete_embedded_dictionaries = std::move(embedded_dictionaries);
delete_external_dictionaries_loader = std::move(external_dictionaries_loader);
delete_external_user_defined_executable_functions_loader = std::move(external_user_defined_executable_functions_loader);
delete_external_models_loader = std::move(external_models_loader);
delete_buffer_flush_schedule_pool = std::move(buffer_flush_schedule_pool);
delete_schedule_pool = std::move(schedule_pool);
delete_distributed_schedule_pool = std::move(distributed_schedule_pool);
@ -458,7 +450,6 @@ struct ContextSharedPart : boost::noncopyable
delete_embedded_dictionaries.reset();
delete_external_dictionaries_loader.reset();
delete_external_user_defined_executable_functions_loader.reset();
delete_external_models_loader.reset();
delete_ddl_worker.reset();
delete_buffer_flush_schedule_pool.reset();
delete_schedule_pool.reset();
@ -1476,48 +1467,6 @@ ExternalUserDefinedExecutableFunctionsLoader & Context::getExternalUserDefinedEx
return *shared->external_user_defined_executable_functions_loader;
}
const ExternalModelsLoader & Context::getExternalModelsLoader() const
{
return const_cast<Context *>(this)->getExternalModelsLoader();
}
ExternalModelsLoader & Context::getExternalModelsLoader()
{
std::lock_guard lock(shared->external_models_mutex);
return getExternalModelsLoaderUnlocked();
}
ExternalModelsLoader & Context::getExternalModelsLoaderUnlocked()
{
if (!shared->external_models_loader)
shared->external_models_loader =
std::make_unique<ExternalModelsLoader>(getGlobalContext());
return *shared->external_models_loader;
}
void Context::loadOrReloadModels(const Poco::Util::AbstractConfiguration & config)
{
auto patterns_values = getMultipleValuesFromConfig(config, "", "models_config");
std::unordered_set<std::string> patterns(patterns_values.begin(), patterns_values.end());
std::lock_guard lock(shared->external_models_mutex);
auto & external_models_loader = getExternalModelsLoaderUnlocked();
if (shared->external_models_config_repository)
{
shared->external_models_config_repository->updatePatterns(patterns);
external_models_loader.reloadConfig(shared->external_models_config_repository->getName());
return;
}
auto app_path = getPath();
auto config_path = getConfigRef().getString("config-file", "config.xml");
auto repository = std::make_unique<ExternalLoaderXMLConfigRepository>(app_path, config_path, patterns);
shared->external_models_config_repository = repository.get();
shared->models_repository_guard = external_models_loader.addConfigRepository(std::move(repository));
}
EmbeddedDictionaries & Context::getEmbeddedDictionariesImpl(const bool throw_on_error) const
{
std::lock_guard lock(shared->embedded_dictionaries_mutex);

View File

@ -53,7 +53,6 @@ class AccessRightsElements;
enum class RowPolicyFilterType;
class EmbeddedDictionaries;
class ExternalDictionariesLoader;
class ExternalModelsLoader;
class ExternalUserDefinedExecutableFunctionsLoader;
class InterserverCredentials;
using InterserverCredentialsPtr = std::shared_ptr<const InterserverCredentials>;
@ -645,19 +644,15 @@ public:
const EmbeddedDictionaries & getEmbeddedDictionaries() const;
const ExternalDictionariesLoader & getExternalDictionariesLoader() const;
const ExternalModelsLoader & getExternalModelsLoader() const;
const ExternalUserDefinedExecutableFunctionsLoader & getExternalUserDefinedExecutableFunctionsLoader() const;
EmbeddedDictionaries & getEmbeddedDictionaries();
ExternalDictionariesLoader & getExternalDictionariesLoader();
ExternalDictionariesLoader & getExternalDictionariesLoaderUnlocked();
ExternalUserDefinedExecutableFunctionsLoader & getExternalUserDefinedExecutableFunctionsLoader();
ExternalUserDefinedExecutableFunctionsLoader & getExternalUserDefinedExecutableFunctionsLoaderUnlocked();
ExternalModelsLoader & getExternalModelsLoader();
ExternalModelsLoader & getExternalModelsLoaderUnlocked();
void tryCreateEmbeddedDictionaries(const Poco::Util::AbstractConfiguration & config) const;
void loadOrReloadDictionaries(const Poco::Util::AbstractConfiguration & config);
void loadOrReloadUserDefinedExecutableFunctions(const Poco::Util::AbstractConfiguration & config);
void loadOrReloadModels(const Poco::Util::AbstractConfiguration & config);
#if USE_NLP
SynonymsExtensions & getSynonymsExtensions() const;

View File

@ -1,41 +0,0 @@
#include <Interpreters/ExternalModelsLoader.h>
#include <Interpreters/Context.h>
namespace DB
{
namespace ErrorCodes
{
extern const int INVALID_CONFIG_PARAMETER;
}
ExternalModelsLoader::ExternalModelsLoader(ContextPtr context_)
: ExternalLoader("external model", &Poco::Logger::get("ExternalModelsLoader")), WithContext(context_)
{
setConfigSettings({"model", "name", {}, {}});
enablePeriodicUpdates(true);
}
std::shared_ptr<const IExternalLoadable> ExternalModelsLoader::create(
const std::string & name, const Poco::Util::AbstractConfiguration & config,
const std::string & config_prefix, const std::string & /* repository_name */) const
{
String type = config.getString(config_prefix + ".type");
ExternalLoadableLifetime lifetime(config, config_prefix + ".lifetime");
/// TODO: add models factory.
if (type == "catboost")
{
return std::make_unique<CatBoostModel>(
name, config.getString(config_prefix + ".path"),
getContext()->getConfigRef().getString("catboost_dynamic_library_path"),
lifetime
);
}
else
{
throw Exception("Unknown model type: " + type, ErrorCodes::INVALID_CONFIG_PARAMETER);
}
}
}

View File

@ -1,40 +0,0 @@
#pragma once
#include <Interpreters/CatBoostModel.h>
#include <Interpreters/Context_fwd.h>
#include <Interpreters/ExternalLoader.h>
#include <Common/logger_useful.h>
#include <memory>
namespace DB
{
/// Manages user-defined models.
class ExternalModelsLoader : public ExternalLoader, WithContext
{
public:
using ModelPtr = std::shared_ptr<const IMLModel>;
/// Models will be loaded immediately and then will be updated in separate thread, each 'reload_period' seconds.
explicit ExternalModelsLoader(ContextPtr context_);
ModelPtr getModel(const std::string & model_name) const
{
return std::static_pointer_cast<const IMLModel>(load(model_name));
}
void reloadModel(const std::string & model_name) const
{
loadOrReload(model_name);
}
protected:
LoadablePtr create(const std::string & name, const Poco::Util::AbstractConfiguration & config,
const std::string & config_prefix, const std::string & repository_name) const override;
friend class StorageSystemModels;
};
}

View File

@ -12,7 +12,6 @@
#include <Interpreters/Context.h>
#include <Interpreters/DatabaseCatalog.h>
#include <Interpreters/ExternalDictionariesLoader.h>
#include <Interpreters/ExternalModelsLoader.h>
#include <Interpreters/ExternalUserDefinedExecutableFunctionsLoader.h>
#include <Interpreters/EmbeddedDictionaries.h>
#include <Interpreters/ActionLocksManager.h>
@ -36,6 +35,7 @@
#include <Interpreters/ProcessorsProfileLog.h>
#include <Interpreters/JIT/CompiledExpressionCache.h>
#include <Interpreters/TransactionLog.h>
#include <BridgeHelper/CatBoostLibraryBridgeHelper.h>
#include <Access/ContextAccess.h>
#include <Access/Common/AllowedClientHosts.h>
#include <Databases/IDatabase.h>
@ -387,17 +387,15 @@ BlockIO InterpreterSystemQuery::execute()
case Type::RELOAD_MODEL:
{
getContext()->checkAccess(AccessType::SYSTEM_RELOAD_MODEL);
auto & external_models_loader = system_context->getExternalModelsLoader();
external_models_loader.reloadModel(query.target_model);
auto bridge_helper = std::make_unique<CatBoostLibraryBridgeHelper>(getContext(), query.target_model);
bridge_helper->removeModel();
break;
}
case Type::RELOAD_MODELS:
{
getContext()->checkAccess(AccessType::SYSTEM_RELOAD_MODEL);
auto & external_models_loader = system_context->getExternalModelsLoader();
external_models_loader.reloadAllTriedToLoad();
auto bridge_helper = std::make_unique<CatBoostLibraryBridgeHelper>(getContext());
bridge_helper->removeAllModels();
break;
}
case Type::RELOAD_FUNCTION:

View File

@ -1,11 +1,11 @@
#include <Storages/System/StorageSystemModels.h>
#include <Common/ExternalModelInfo.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeDateTime.h>
#include <DataTypes/DataTypeEnum.h>
#include <Interpreters/Context.h>
#include <Interpreters/ExternalModelsLoader.h>
#include <Interpreters/CatBoostModel.h>
#include <BridgeHelper/CatBoostLibraryBridgeHelper.h>
namespace DB
@ -14,45 +14,24 @@ namespace DB
NamesAndTypesList StorageSystemModels::getNamesAndTypes()
{
return {
{ "name", std::make_shared<DataTypeString>() },
{ "status", std::make_shared<DataTypeEnum8>(getStatusEnumAllPossibleValues()) },
{ "origin", std::make_shared<DataTypeString>() },
{ "model_path", std::make_shared<DataTypeString>() },
{ "type", std::make_shared<DataTypeString>() },
{ "loading_start_time", std::make_shared<DataTypeDateTime>() },
{ "loading_duration", std::make_shared<DataTypeFloat32>() },
//{ "creation_time", std::make_shared<DataTypeDateTime>() },
{ "last_exception", std::make_shared<DataTypeString>() },
};
}
void StorageSystemModels::fillData(MutableColumns & res_columns, ContextPtr context, const SelectQueryInfo &) const
{
const auto & external_models_loader = context->getExternalModelsLoader();
auto load_results = external_models_loader.getLoadResults();
auto bridge_helper = std::make_unique<CatBoostLibraryBridgeHelper>(context);
ExternalModelInfos infos = bridge_helper->listModels();
for (const auto & load_result : load_results)
for (const auto & info : infos)
{
res_columns[0]->insert(load_result.name);
res_columns[1]->insert(static_cast<Int8>(load_result.status));
res_columns[2]->insert(load_result.config ? load_result.config->path : "");
if (load_result.object)
{
const auto model_ptr = std::static_pointer_cast<const IMLModel>(load_result.object);
res_columns[3]->insert(model_ptr->getTypeName());
}
else
{
res_columns[3]->insertDefault();
}
res_columns[4]->insert(static_cast<UInt64>(std::chrono::system_clock::to_time_t(load_result.loading_start_time)));
res_columns[5]->insert(std::chrono::duration_cast<std::chrono::duration<float>>(load_result.loading_duration).count());
if (load_result.exception)
res_columns[6]->insert(getExceptionMessage(load_result.exception, false));
else
res_columns[6]->insertDefault();
res_columns[0]->insert(info.model_path);
res_columns[1]->insert(info.model_type);
res_columns[2]->insert(static_cast<UInt64>(std::chrono::system_clock::to_time_t(info.loading_start_time)));
res_columns[3]->insert(std::chrono::duration_cast<std::chrono::duration<float>>(info.loading_duration).count());
}
}

View File

@ -763,7 +763,6 @@
"MINUTE"
"MM"
"mod"
"modelEvaluate"
"MODIFY"
"MODIFY COLUMN"
"MODIFY ORDER BY"

View File

@ -469,7 +469,6 @@
"subtractSeconds"
"alphaTokens"
"negate"
"modelEvaluate"
"file"
"roundAge"
"MACStringToOUI"

View File

@ -0,0 +1,3 @@
<clickhouse>
<catboost_lib_path>/etc/clickhouse-server/model/libcatboostmodel.so</catboost_lib_path>
</clickhouse>

View File

@ -0,0 +1,402 @@
import os
import sys
import time
import pytest
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
from helpers.cluster import ClickHouseCluster
cluster = ClickHouseCluster(__file__)
instance = cluster.add_instance(
    "instance", stay_alive=True, main_configs=["config/models_config.xml"]
)
@pytest.fixture(scope="module")
def ch_cluster():
    try:
        cluster.start()
        os.system(
            "docker cp {local} {cont_id}:{dist}".format(
                local=os.path.join(SCRIPT_DIR, "model/."),
                cont_id=instance.docker_id,
                dist="/etc/clickhouse-server/model",
            )
        )
        instance.restart_clickhouse()
        yield cluster
    finally:
        cluster.shutdown()
# ---------------------------------------------------------------------------
# simple_model.bin has 2 float features and 9 categorical features
def testConstantFeatures(ch_cluster):
    if instance.is_built_with_memory_sanitizer():
        pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
    result = instance.query("system reload models")
    result = instance.query(
        "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
    )
    expected = "-1.930268705869267\n"
    assert result == expected
def testNonConstantFeatures(ch_cluster):
    if instance.is_built_with_memory_sanitizer():
        pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
    result = instance.query("system reload models")
    instance.query("DROP TABLE IF EXISTS T;")
    instance.query(
        "CREATE TABLE T(ID UInt32, F1 Float32, F2 Float32, F3 UInt32, F4 UInt32, F5 UInt32, F6 UInt32, F7 UInt32, F8 UInt32, F9 Float32, F10 Float32, F11 Float32) ENGINE MergeTree ORDER BY ID;"
    )
    instance.query("INSERT INTO T VALUES(0, 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);")
    result = instance.query(
        "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', F1, F2, F3, F4, F5, F6, F7, F8, F9, F10, F11) from T;"
    )
    expected = "-1.930268705869267\n"
    assert result == expected
    instance.query("DROP TABLE IF EXISTS T;")
def testModelPathIsNotAConstString(ch_cluster):
    if instance.is_built_with_memory_sanitizer():
        pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
    result = instance.query("system reload models")
    err = instance.query_and_get_error(
        "select catboostEvaluate(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
    )
    assert (
        "Illegal type UInt8 of first argument of function catboostEvaluate, expected a string"
        in err
    )
    instance.query("DROP TABLE IF EXISTS T;")
    instance.query("CREATE TABLE T(ID UInt32, A String) ENGINE MergeTree ORDER BY ID")
    instance.query("INSERT INTO T VALUES(0, 'test');")
    err = instance.query_and_get_error(
        "select catboostEvaluate(A, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11) FROM T;"
    )
    assert (
        "First argument of function catboostEvaluate must be a constant string" in err
    )
    instance.query("DROP TABLE IF EXISTS T;")
def testWrongNumberOfFeatureArguments(ch_cluster):
    if instance.is_built_with_memory_sanitizer():
        pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
    result = instance.query("system reload models")
    err = instance.query_and_get_error(
        "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin');"
    )
    assert "Function catboostEvaluate expects at least 2 arguments" in err
    err = instance.query_and_get_error(
        "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1, 2);"
    )
    assert (
        "Number of columns is different with number of features: columns size 2 float features size 2 + cat features size 9"
        in err
    )
def testFloatFeatureMustBeNumeric(ch_cluster):
    if instance.is_built_with_memory_sanitizer():
        pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
    result = instance.query("system reload models")
    err = instance.query_and_get_error(
        "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 'a', 3, 4, 5, 6, 7, 8, 9, 10, 11);"
    )
    assert "Column 1 should be numeric to make float feature" in err
def testCategoricalFeatureMustBeNumericOrString(ch_cluster):
    if instance.is_built_with_memory_sanitizer():
        pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
    result = instance.query("system reload models")
    err = instance.query_and_get_error(
        "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, tuple(8), 9, 10, 11);"
    )
    assert "Column 7 should be numeric or string" in err
def testOnLowCardinalityFeatures(ch_cluster):
    if instance.is_built_with_memory_sanitizer():
        pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
    result = instance.query("system reload models")
    # same but on domain-compressed data
    result = instance.query(
        "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', toLowCardinality(1.0), toLowCardinality(2.0), toLowCardinality(3), toLowCardinality(4), toLowCardinality(5), toLowCardinality(6), toLowCardinality(7), toLowCardinality(8), toLowCardinality(9), toLowCardinality(10), toLowCardinality(11));"
    )
    expected = "-1.930268705869267\n"
    assert result == expected
def testOnNullableFeatures(ch_cluster):
    if instance.is_built_with_memory_sanitizer():
        pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
    result = instance.query("system reload models")
    result = instance.query(
        "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', toNullable(1.0), toNullable(2.0), toNullable(3), toNullable(4), toNullable(5), toNullable(6), toNullable(7), toNullable(8), toNullable(9), toNullable(10), toNullable(11));"
    )
    expected = "-1.930268705869267\n"
    assert result == expected
    # Actual NULLs are disallowed
    err = instance.query_and_get_error(
        "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', toNullable(NULL), toNullable(NULL), toNullable(NULL), toNullable(NULL), toNullable(NULL), toNullable(NULL), toNullable(NULL), toNullable(NULL), toNullable(NULL), toNullable(NULL), toNullable(NULL));"
    )
    assert "Column 0 should be numeric to make float feature" in err
def testInvalidLibraryPath(ch_cluster):
    if instance.is_built_with_memory_sanitizer():
        pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
    result = instance.query("system reload models")
    # temporarily move library elsewhere
    instance.exec_in_container(
        [
            "bash",
            "-c",
            "mv /etc/clickhouse-server/model/libcatboostmodel.so /etc/clickhouse-server/model/nonexistant.so",
        ]
    )
    err = instance.query_and_get_error(
        "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
    )
    assert (
        "Can't load library /etc/clickhouse-server/model/libcatboostmodel.so: file doesn't exist"
        in err
    )
    # restore
    instance.exec_in_container(
        [
            "bash",
            "-c",
            "mv /etc/clickhouse-server/model/nonexistant.so /etc/clickhouse-server/model/libcatboostmodel.so",
        ]
    )
def testInvalidModelPath(ch_cluster):
    if instance.is_built_with_memory_sanitizer():
        pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
    result = instance.query("system reload models")
    err = instance.query_and_get_error(
        "select catboostEvaluate('', 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
    )
    assert "Can't load model : file doesn't exist" in err
    err = instance.query_and_get_error(
        "select catboostEvaluate('model_non_existant.bin', 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
    )
    assert "Can't load model model_non_existant.bin: file doesn't exist" in err
def testRecoveryAfterCrash(ch_cluster):
    if instance.is_built_with_memory_sanitizer():
        pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
    result = instance.query("system reload models")
    result = instance.query(
        "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
    )
    expected = "-1.930268705869267\n"
    assert result == expected
    instance.exec_in_container(
        ["bash", "-c", "kill -9 `pidof clickhouse-library-bridge`"], user="root"
    )
    result = instance.query(
        "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
    )
    assert result == expected
# ---------------------------------------------------------------------------
# amazon_model.bin has 0 float features and 9 categorical features
def testAmazonModelSingleRow(ch_cluster):
    if instance.is_built_with_memory_sanitizer():
        pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
    result = instance.query("system reload models")
    result = instance.query(
        "select catboostEvaluate('/etc/clickhouse-server/model/amazon_model.bin', 1, 2, 3, 4, 5, 6, 7, 8, 9);"
    )
    expected = "0.7774665009089274\n"
    assert result == expected
def testAmazonModelManyRows(ch_cluster):
    if instance.is_built_with_memory_sanitizer():
        pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
    result = instance.query("system reload models")
    result = instance.query("drop table if exists amazon")
    result = instance.query(
        "create table amazon ( DATE Date materialized today(), ACTION UInt8, RESOURCE UInt32, MGR_ID UInt32, ROLE_ROLLUP_1 UInt32, ROLE_ROLLUP_2 UInt32, ROLE_DEPTNAME UInt32, ROLE_TITLE UInt32, ROLE_FAMILY_DESC UInt32, ROLE_FAMILY UInt32, ROLE_CODE UInt32) engine = MergeTree order by DATE"
    )
    result = instance.query(
        "insert into amazon select number % 256, number, number, number, number, number, number, number, number, number from numbers(7500)"
    )
    # First compute prediction, then as a very crude way to fingerprint and compare the result: sum and floor
    # (the focus is to test that the exchange of large result sets between the server and the bridge works)
    result = instance.query(
        "SELECT floor(sum(catboostEvaluate('/etc/clickhouse-server/model/amazon_model.bin', RESOURCE, MGR_ID, ROLE_ROLLUP_1, ROLE_ROLLUP_2, ROLE_DEPTNAME, ROLE_TITLE, ROLE_FAMILY_DESC, ROLE_FAMILY, ROLE_CODE))) FROM amazon"
    )
    expected = "5834\n"
    assert result == expected
    result = instance.query("drop table if exists amazon")
def testModelUpdate(ch_cluster):
    if instance.is_built_with_memory_sanitizer():
        pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
    result = instance.query("system reload models")
    query = "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
    result = instance.query(query)
    expected = "-1.930268705869267\n"
    assert result == expected
    # simulate an update of the model: temporarily move the amazon model in place of the simple model
    instance.exec_in_container(
        [
            "bash",
            "-c",
            "mv /etc/clickhouse-server/model/simple_model.bin /etc/clickhouse-server/model/simple_model.bin.bak",
        ]
    )
    instance.exec_in_container(
        [
            "bash",
            "-c",
            "mv /etc/clickhouse-server/model/amazon_model.bin /etc/clickhouse-server/model/simple_model.bin",
        ]
    )
    # unload simple model
    result = instance.query(
        "system reload model '/etc/clickhouse-server/model/simple_model.bin'"
    )
    # load the simple-model-camouflaged amazon model
    result = instance.query(
        "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1, 2, 3, 4, 5, 6, 7, 8, 9);"
    )
    expected = "0.7774665009089274\n"
    assert result == expected
    # restore
    instance.exec_in_container(
        [
            "bash",
            "-c",
            "mv /etc/clickhouse-server/model/simple_model.bin /etc/clickhouse-server/model/amazon_model.bin",
        ]
    )
    instance.exec_in_container(
        [
            "bash",
            "-c",
            "mv /etc/clickhouse-server/model/simple_model.bin.bak /etc/clickhouse-server/model/simple_model.bin",
        ]
    )
def testSystemModelsAndModelRefresh(ch_cluster):
    if instance.is_built_with_memory_sanitizer():
        pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
    result = instance.query("system reload models")
    # check model system view: no models are loaded yet
    result = instance.query("select * from system.models")
    expected = ""
    assert result == expected
    # load simple model
    result = instance.query(
        "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
    )
    expected = "-1.930268705869267\n"
    assert result == expected
    # check model system view with one model loaded
    result = instance.query("select * from system.models")
    assert result.count("\n") == 1
    expected = "/etc/clickhouse-server/model/simple_model.bin"
    assert expected in result
    # load amazon model
    result = instance.query(
        "select catboostEvaluate('/etc/clickhouse-server/model/amazon_model.bin', 1, 2, 3, 4, 5, 6, 7, 8, 9);"
    )
    expected = "0.7774665009089274\n"
    assert result == expected
    # check model system view with two models loaded
    result = instance.query("select * from system.models")
    assert result.count("\n") == 2
    expected = "/etc/clickhouse-server/model/simple_model.bin"
    assert expected in result
    expected = "/etc/clickhouse-server/model/amazon_model.bin"
    assert expected in result
    # unload simple model
    result = instance.query(
        "system reload model '/etc/clickhouse-server/model/simple_model.bin'"
    )
    # check model system view, it should not display the removed model
    result = instance.query("select * from system.models")
    assert result.count("\n") == 1
    expected = "/etc/clickhouse-server/model/amazon_model.bin"
    assert expected in result
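
Reviewer note: the `system.models` checks in the last test count newlines and do substring matches on the raw query output. A minimal sketch of a stricter check is shown below. It assumes the test helper returns tab-separated rows with one trailing newline per row, and that `model_path` is the first column (as the new `fillData` suggests). The names `tsv_rows` and `loaded_model_paths` are hypothetical helpers, not part of the test framework.

```python
from typing import List, Set


def tsv_rows(result: str) -> List[List[str]]:
    """Split tab-separated query output into rows of fields, skipping empty lines."""
    return [line.split("\t") for line in result.splitlines() if line]


def loaded_model_paths(result: str) -> Set[str]:
    """Return the set of model paths, assuming model_path is the first column."""
    return {row[0] for row in tsv_rows(result)}
```

With such helpers, an assertion like `loaded_model_paths(result) == {"/etc/clickhouse-server/model/amazon_model.bin"}` would check the row count and the path in one step.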

View File

@ -1,3 +0,0 @@
<clickhouse>
<catboost_dynamic_library_path>/etc/clickhouse-server/model/libcatboostmodel.so</catboost_dynamic_library_path>
</clickhouse>

View File

@ -1,2 +0,0 @@
<clickhouse>
</clickhouse>

View File

@ -1,8 +0,0 @@
<models>
<model>
<type>catboost</type>
<name>model1</name>
<path>/etc/clickhouse-server/model/model.bin</path>
<lifetime>0</lifetime>
</model>
</models>

View File

@ -1,8 +0,0 @@
<models>
<model>
<type>catboost</type>
<name>model2</name>
<path>/etc/clickhouse-server/model/model.bin</path>
<lifetime>0</lifetime>
</model>
</models>

View File

@ -1,77 +0,0 @@
import os
import sys
import time
import pytest
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
from helpers.cluster import ClickHouseCluster
cluster = ClickHouseCluster(__file__)
node = cluster.add_instance(
"node",
stay_alive=True,
main_configs=["config/models_config.xml", "config/catboost_lib.xml"],
)
def copy_file_to_container(local_path, dist_path, container_id):
os.system(
"docker cp {local} {cont_id}:{dist}".format(
local=local_path, cont_id=container_id, dist=dist_path
)
)
config = """<clickhouse>
<models_config>/etc/clickhouse-server/model/{model_config}</models_config>
</clickhouse>"""
@pytest.fixture(scope="module")
def started_cluster():
try:
cluster.start()
copy_file_to_container(
os.path.join(SCRIPT_DIR, "model/."),
"/etc/clickhouse-server/model",
node.docker_id,
)
node.restart_clickhouse()
yield cluster
finally:
cluster.shutdown()
def change_config(model_config):
node.replace_config(
"/etc/clickhouse-server/config.d/models_config.xml",
config.format(model_config=model_config),
)
node.query("SYSTEM RELOAD CONFIG;")
def test(started_cluster):
if node.is_built_with_memory_sanitizer():
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
# Set config with the path to the first model.
change_config("model_config.xml")
node.query("SELECT modelEvaluate('model1', 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);")
# Change path to the second model in config.
change_config("model_config2.xml")
# Check that the new model is loaded.
node.query("SELECT modelEvaluate('model2', 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);")
# Check that the old model was unloaded.
node.query_and_get_error(
"SELECT modelEvaluate('model1', 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
)

View File

@ -1,4 +0,0 @@
<clickhouse>
<catboost_dynamic_library_path>/etc/clickhouse-server/model/libcatboostmodel.so</catboost_dynamic_library_path>
<models_config>/etc/clickhouse-server/model/model_config.xml</models_config>
</clickhouse>

View File

@ -1,8 +0,0 @@
<models>
<model>
<type>catboost</type>
<name>titanic</name>
<path>/etc/clickhouse-server/model/model.bin</path>
<lifetime>0</lifetime>
</model>
</models>

View File

@ -1,48 +0,0 @@
import os
import sys
import time
import pytest
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
from helpers.cluster import ClickHouseCluster
cluster = ClickHouseCluster(__file__)
node = cluster.add_instance(
"node", stay_alive=True, main_configs=["config/models_config.xml"]
)
def copy_file_to_container(local_path, dist_path, container_id):
os.system(
"docker cp {local} {cont_id}:{dist}".format(
local=local_path, cont_id=container_id, dist=dist_path
)
)
@pytest.fixture(scope="module")
def started_cluster():
try:
cluster.start()
copy_file_to_container(
os.path.join(SCRIPT_DIR, "model/."),
"/etc/clickhouse-server/model",
node.docker_id,
)
node.restart_clickhouse()
yield cluster
finally:
cluster.shutdown()
def test(started_cluster):
if node.is_built_with_memory_sanitizer():
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
node.query("select modelEvaluate('titanic', 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);")

View File

@ -1,3 +0,0 @@
<clickhouse>
<catboost_dynamic_library_path>/etc/clickhouse-server/model/libcatboostmodel.so</catboost_dynamic_library_path>
</clickhouse>

View File

@ -1,3 +0,0 @@
<clickhouse>
<models_config>/etc/clickhouse-server/model/model_config.xml</models_config>
</clickhouse>

View File

@ -1,8 +0,0 @@
<models>
<model>
<type>catboost</type>
<name>model</name>
<path>/etc/clickhouse-server/model/model.cbm</path>
<lifetime>0</lifetime>
</model>
</models>

View File

@ -1,132 +0,0 @@
import os
import sys
import time
import pytest
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
from helpers.cluster import ClickHouseCluster
cluster = ClickHouseCluster(__file__)
node = cluster.add_instance(
"node",
stay_alive=True,
main_configs=["config/models_config.xml", "config/catboost_lib.xml"],
)
def copy_file_to_container(local_path, dist_path, container_id):
os.system(
"docker cp {local} {cont_id}:{dist}".format(
local=local_path, cont_id=container_id, dist=dist_path
)
)
@pytest.fixture(scope="module")
def started_cluster():
try:
cluster.start()
copy_file_to_container(
os.path.join(SCRIPT_DIR, "model/."),
"/etc/clickhouse-server/model",
node.docker_id,
)
node.query("CREATE TABLE binary (x UInt64, y UInt64) ENGINE = TinyLog()")
node.query("INSERT INTO binary VALUES (1, 1), (1, 0), (0, 1), (0, 0)")
node.restart_clickhouse()
yield cluster
finally:
cluster.shutdown()
def test_model_reload(started_cluster):
if node.is_built_with_memory_sanitizer():
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
node.exec_in_container(
["bash", "-c", "rm -f /etc/clickhouse-server/model/model.cbm"]
)
node.exec_in_container(
[
"bash",
"-c",
"ln /etc/clickhouse-server/model/conjunction.cbm /etc/clickhouse-server/model/model.cbm",
]
)
node.query("SYSTEM RELOAD MODEL model")
result = node.query(
"""
WITH modelEvaluate('model', toFloat64(x), toFloat64(y)) as prediction, exp(prediction) / (1 + exp(prediction)) as probability
SELECT if(probability > 0.5, 1, 0) FROM binary;
"""
)
assert result == "1\n0\n0\n0\n"
node.exec_in_container(["bash", "-c", "rm /etc/clickhouse-server/model/model.cbm"])
node.exec_in_container(
[
"bash",
"-c",
"ln /etc/clickhouse-server/model/disjunction.cbm /etc/clickhouse-server/model/model.cbm",
]
)
node.query("SYSTEM RELOAD MODEL model")
result = node.query(
"""
WITH modelEvaluate('model', toFloat64(x), toFloat64(y)) as prediction, exp(prediction) / (1 + exp(prediction)) as probability
SELECT if(probability > 0.5, 1, 0) FROM binary;
"""
)
assert result == "1\n1\n1\n0\n"
def test_models_reload(started_cluster):
if node.is_built_with_memory_sanitizer():
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
node.exec_in_container(
["bash", "-c", "rm -f /etc/clickhouse-server/model/model.cbm"]
)
node.exec_in_container(
[
"bash",
"-c",
"ln /etc/clickhouse-server/model/conjunction.cbm /etc/clickhouse-server/model/model.cbm",
]
)
node.query("SYSTEM RELOAD MODELS")
result = node.query(
"""
WITH modelEvaluate('model', toFloat64(x), toFloat64(y)) as prediction, exp(prediction) / (1 + exp(prediction)) as probability
SELECT if(probability > 0.5, 1, 0) FROM binary;
"""
)
assert result == "1\n0\n0\n0\n"
node.exec_in_container(["bash", "-c", "rm /etc/clickhouse-server/model/model.cbm"])
node.exec_in_container(
[
"bash",
"-c",
"ln /etc/clickhouse-server/model/disjunction.cbm /etc/clickhouse-server/model/model.cbm",
]
)
node.query("SYSTEM RELOAD MODELS")
result = node.query(
"""
WITH modelEvaluate('model', toFloat64(x), toFloat64(y)) as prediction, exp(prediction) / (1 + exp(prediction)) as probability
SELECT if(probability > 0.5, 1, 0) FROM binary;
"""
)
assert result == "1\n1\n1\n0\n"
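
Reviewer note: the removed reload tests above convert the raw model output to a probability with `exp(prediction) / (1 + exp(prediction))` and then threshold at 0.5. The same post-processing can be sketched in plain Python as below. This assumes the raw output is a log-odds value, as those queries treat it. The function names are hypothetical.

```python
import math


def logit_to_probability(raw_prediction: float) -> float:
    """Logistic function, equivalent to exp(x) / (1 + exp(x))."""
    # Branch on the sign so that math.exp never receives a large positive argument.
    if raw_prediction >= 0:
        return 1.0 / (1.0 + math.exp(-raw_prediction))
    e = math.exp(raw_prediction)
    return e / (1.0 + e)


def classify(raw_prediction: float, threshold: float = 0.5) -> int:
    """Return 1 if the probability is strictly above the threshold, else 0."""
    return 1 if logit_to_probability(raw_prediction) > threshold else 0
```

A raw prediction of exactly 0 maps to probability 0.5 and is classified as 0, matching the strict `>` comparison in the SQL.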

View File

@ -16,7 +16,7 @@ function run_selects()
{
thread_num=$1
readarray -t tables_arr < <(${CLICKHOUSE_CLIENT} -q "SELECT database || '.' || name FROM system.tables
WHERE database in ('system', 'information_schema', 'INFORMATION_SCHEMA') and name!='zookeeper' and name!='merge_tree_metadata_cache'
WHERE database in ('system', 'information_schema', 'INFORMATION_SCHEMA') and name!='zookeeper' and name!='merge_tree_metadata_cache' and name!='models'
AND sipHash64(name || toString($RAND)) % $THREADS = $thread_num")
for t in "${tables_arr[@]}"

View File

@ -364,18 +364,6 @@ CREATE TABLE system.metrics
)
ENGINE = SystemMetrics
COMMENT 'SYSTEM TABLE is built on the fly.'
CREATE TABLE system.models
(
`name` String,
`status` Enum8('NOT_LOADED' = 0, 'LOADED' = 1, 'FAILED' = 2, 'LOADING' = 3, 'FAILED_AND_RELOADING' = 4, 'LOADED_AND_RELOADING' = 5, 'NOT_EXIST' = 6),
`origin` String,
`type` String,
`loading_start_time` DateTime,
`loading_duration` Float32,
`last_exception` String
)
ENGINE = SystemModels
COMMENT 'SYSTEM TABLE is built on the fly.'
CREATE TABLE system.mutations
(
`database` String,

View File

@ -45,7 +45,6 @@ show create table macros format TSVRaw;
show create table merge_tree_settings format TSVRaw;
show create table merges format TSVRaw;
show create table metrics format TSVRaw;
show create table models format TSVRaw;
show create table mutations format TSVRaw;
show create table numbers format TSVRaw;
show create table numbers_mt format TSVRaw;

View File

@ -1,2 +0,0 @@
-- This model does not exist:
SELECT modelEvaluate('hello', 1, 2, 3); -- { serverError 36 }

View File

@ -192,6 +192,7 @@ caseWithExpr
caseWithExpression
caseWithoutExpr
caseWithoutExpression
catboostEvaluate
cbrt
ceil
char
@ -475,7 +476,6 @@ min2
minSampleSizeContinous
minSampleSizeConversion
minus
modelEvaluate
modulo
moduloLegacy
moduloOrZero