mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-25 00:52:02 +00:00
feat: implement catboost in library-bridge
This commit moves the catboost model evaluation out of the server process into the library-bridge binary. This serves two goals: On the one hand, crashes / memory corruptions of the catboost library no longer affect the server. On the other hand, we can forbid loading dynamic libraries in the server (catboost was the last consumer of this functionality), thus improving security. SQL syntax: SELECT catboostEvaluate('/path/to/model.bin', FEAT_1, ..., FEAT_N) > 0 AS prediction, ACTION AS target FROM amazon_train LIMIT 10 Required configuration: <catboost_lib_path>/path/to/libcatboostmodel.so</catboost_lib_path> *** Implementation Details *** The internal protocol between the server and the library-bridge is simple: - HTTP GET on path "/extdict_ping": A ping, used during the handshake to check if the library-bridge runs. - HTTP POST on path "extdict_request" (1) Send a "catboost_GetTreeCount" request from the server to the bridge, containing a library path (e.g /home/user/libcatboost.so) and a model path (e.g. /home/user/model.bin). Rirst, this unloads the catboost library handler associated to the model path (if it was loaded), then loads the catboost library handler associated to the model path, then executes GetTreeCount() on the library handler and finally sends the result back to the server. Step (1) is called once by the server from FunctionCatBoostEvaluate::getReturnTypeImpl(). The library path handler is unloaded in the beginning because it contains state which may no longer be valid if the user runs catboost("/path/to/model.bin", ...) more than once and if "model.bin" was updated in between. (2) Send "catboost_Evaluate" from the server to the bridge, containing the model path and the features to run the interference on. Step (2) is called multiple times (once per chunk) by the server from function FunctionCatBoostEvaluate::executeImpl(). The library handler for the given model path is expected to be already loaded by Step (1). Fixes #27870
This commit is contained in:
parent
68808858a5
commit
60f9f6855d
4
.gitignore
vendored
4
.gitignore
vendored
@ -58,6 +58,10 @@ cmake_install.cmake
|
||||
CTestTestfile.cmake
|
||||
*.a
|
||||
*.o
|
||||
*.so
|
||||
*.dll
|
||||
*.lib
|
||||
*.dylib
|
||||
cmake-build-*
|
||||
|
||||
# Python cache
|
||||
|
@ -1823,6 +1823,36 @@ Result:
|
||||
Evaluate external model.
|
||||
Accepts a model name and model arguments. Returns Float64.
|
||||
|
||||
## catboostEvaluate(path_to_model, feature_1, feature_2, …, feature_n)
|
||||
|
||||
Evaluate external catboost model. [CatBoost](https://catboost.ai) is an open-source gradient boosting library developed by Yandex for machine learing.
|
||||
Accepts a path to a catboost model and model arguments (features). Returns Float64.
|
||||
|
||||
``` sql
|
||||
SELECT feat1, ..., feat_n, catboostEvaluate('/path/to/model.bin', feat_1, ..., feat_n) AS prediction
|
||||
FROM data_table
|
||||
```
|
||||
|
||||
**Prerequisites**
|
||||
|
||||
1. Build the catboost evaluation library
|
||||
|
||||
Before evaluating catboost models, the `libcatboostmodel.<so|dylib>` library must be made available. See [CatBoost documentation](https://catboost.ai/docs/concepts/c-plus-plus-api_dynamic-c-pluplus-wrapper.html) how to compile it.
|
||||
|
||||
Next, specify the path to `libcatboostmodel.<so|dylib>` in the clickhouse configuration:
|
||||
|
||||
``` xml
|
||||
<clickhouse>
|
||||
...
|
||||
<catboost_lib_path>/path/to/libcatboostmodel.so</catboost_lib_path>
|
||||
...
|
||||
</clickhouse>
|
||||
```
|
||||
|
||||
2. Train a catboost model using libcatboost
|
||||
|
||||
See [Training and applying models](https://catboost.ai/docs/features/training.html#training) for how to train catboost models from a training data set.
|
||||
|
||||
## throwIf(x\[, message\[, error_code\]\])
|
||||
|
||||
Throw an exception if the argument is non zero.
|
||||
|
@ -54,7 +54,7 @@ else ()
|
||||
endif ()
|
||||
|
||||
if (NOT USE_MUSL)
|
||||
option (ENABLE_CLICKHOUSE_LIBRARY_BRIDGE "HTTP-server working like a proxy to Library dictionary source" ${ENABLE_CLICKHOUSE_ALL})
|
||||
option (ENABLE_CLICKHOUSE_LIBRARY_BRIDGE "HTTP-server working like a proxy to external dynamically loaded libraries" ${ENABLE_CLICKHOUSE_ALL})
|
||||
endif ()
|
||||
|
||||
# https://presentations.clickhouse.com/matemarketing_2020/
|
||||
|
@ -1,6 +1,8 @@
|
||||
include(${ClickHouse_SOURCE_DIR}/cmake/split_debug_symbols.cmake)
|
||||
|
||||
set (CLICKHOUSE_LIBRARY_BRIDGE_SOURCES
|
||||
CatBoostLibraryHandler.cpp
|
||||
CatBoostLibraryHandlerFactory.cpp
|
||||
ExternalDictionaryLibraryAPI.cpp
|
||||
ExternalDictionaryLibraryHandler.cpp
|
||||
ExternalDictionaryLibraryHandlerFactory.cpp
|
||||
|
49
programs/library-bridge/CatBoostLibraryAPI.h
Normal file
49
programs/library-bridge/CatBoostLibraryAPI.h
Normal file
@ -0,0 +1,49 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstddef>
|
||||
|
||||
// Function pointer typedefs and names of libcatboost.so functions used by ClickHouse
|
||||
struct CatBoostLibraryAPI
|
||||
{
|
||||
using ModelCalcerHandle = void;
|
||||
|
||||
using ModelCalcerCreateFunc = ModelCalcerHandle * (*)();
|
||||
static constexpr const char * ModelCalcerCreateName = "ModelCalcerCreate";
|
||||
|
||||
using ModelCalcerDeleteFunc = void (*)(ModelCalcerHandle *);
|
||||
static constexpr const char * ModelCalcerDeleteName = "ModelCalcerDelete";
|
||||
|
||||
using GetErrorStringFunc = const char * (*)();
|
||||
static constexpr const char * GetErrorStringName = "GetErrorString";
|
||||
|
||||
using LoadFullModelFromFileFunc = bool (*)(ModelCalcerHandle *, const char *);
|
||||
static constexpr const char * LoadFullModelFromFileName = "LoadFullModelFromFile";
|
||||
|
||||
using CalcModelPredictionFlatFunc = bool (*)(ModelCalcerHandle *, size_t, const float **, size_t, double *, size_t);
|
||||
static constexpr const char * CalcModelPredictionFlatName = "CalcModelPredictionFlat";
|
||||
|
||||
using CalcModelPredictionFunc = bool (*)(ModelCalcerHandle *, size_t, const float **, size_t, const char ***, size_t, double *, size_t);
|
||||
static constexpr const char * CalcModelPredictionName = "CalcModelPrediction";
|
||||
|
||||
using CalcModelPredictionWithHashedCatFeaturesFunc = bool (*)(ModelCalcerHandle *, size_t, const float **, size_t, const int **, size_t, double *, size_t);
|
||||
static constexpr const char * CalcModelPredictionWithHashedCatFeaturesName = "CalcModelPredictionWithHashedCatFeatures";
|
||||
|
||||
using GetStringCatFeatureHashFunc = int (*)(const char *, size_t);
|
||||
static constexpr const char * GetStringCatFeatureHashName = "GetStringCatFeatureHash";
|
||||
|
||||
using GetIntegerCatFeatureHashFunc = int (*)(uint64_t);
|
||||
static constexpr const char * GetIntegerCatFeatureHashName = "GetIntegerCatFeatureHash";
|
||||
|
||||
using GetFloatFeaturesCountFunc = size_t (*)(ModelCalcerHandle *);
|
||||
static constexpr const char * GetFloatFeaturesCountName = "GetFloatFeaturesCount";
|
||||
|
||||
using GetCatFeaturesCountFunc = size_t (*)(ModelCalcerHandle *);
|
||||
static constexpr const char * GetCatFeaturesCountName = "GetCatFeaturesCount";
|
||||
|
||||
using GetTreeCountFunc = size_t (*)(ModelCalcerHandle *);
|
||||
static constexpr const char * GetTreeCountName = "GetTreeCount";
|
||||
|
||||
using GetDimensionsCountFunc = size_t (*)(ModelCalcerHandle *);
|
||||
static constexpr const char * GetDimensionsCountName = "GetDimensionsCount";
|
||||
};
|
376
programs/library-bridge/CatBoostLibraryHandler.cpp
Normal file
376
programs/library-bridge/CatBoostLibraryHandler.cpp
Normal file
@ -0,0 +1,376 @@
|
||||
#include "CatBoostLibraryHandler.h"
|
||||
|
||||
#include <Columns/ColumnTuple.h>
|
||||
#include <Common/FieldVisitorConvertToNumber.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int BAD_ARGUMENTS;
|
||||
extern const int CANNOT_APPLY_CATBOOST_MODEL;
|
||||
extern const int CANNOT_LOAD_CATBOOST_MODEL;
|
||||
extern const int LOGICAL_ERROR;
|
||||
}
|
||||
|
||||
CatBoostLibraryHandler::APIHolder::APIHolder(SharedLibrary & lib)
|
||||
{
|
||||
ModelCalcerCreate = lib.get<CatBoostLibraryAPI::ModelCalcerCreateFunc>(CatBoostLibraryAPI::ModelCalcerCreateName);
|
||||
ModelCalcerDelete = lib.get<CatBoostLibraryAPI::ModelCalcerDeleteFunc>(CatBoostLibraryAPI::ModelCalcerDeleteName);
|
||||
GetErrorString = lib.get<CatBoostLibraryAPI::GetErrorStringFunc>(CatBoostLibraryAPI::GetErrorStringName);
|
||||
LoadFullModelFromFile = lib.get<CatBoostLibraryAPI::LoadFullModelFromFileFunc>(CatBoostLibraryAPI::LoadFullModelFromFileName);
|
||||
CalcModelPredictionFlat = lib.get<CatBoostLibraryAPI::CalcModelPredictionFlatFunc>(CatBoostLibraryAPI::CalcModelPredictionFlatName);
|
||||
CalcModelPrediction = lib.get<CatBoostLibraryAPI::CalcModelPredictionFunc>(CatBoostLibraryAPI::CalcModelPredictionName);
|
||||
CalcModelPredictionWithHashedCatFeatures = lib.get<CatBoostLibraryAPI::CalcModelPredictionWithHashedCatFeaturesFunc>(CatBoostLibraryAPI::CalcModelPredictionWithHashedCatFeaturesName);
|
||||
GetStringCatFeatureHash = lib.get<CatBoostLibraryAPI::GetStringCatFeatureHashFunc>(CatBoostLibraryAPI::GetStringCatFeatureHashName);
|
||||
GetIntegerCatFeatureHash = lib.get<CatBoostLibraryAPI::GetIntegerCatFeatureHashFunc>(CatBoostLibraryAPI::GetIntegerCatFeatureHashName);
|
||||
GetFloatFeaturesCount = lib.get<CatBoostLibraryAPI::GetFloatFeaturesCountFunc>(CatBoostLibraryAPI::GetFloatFeaturesCountName);
|
||||
GetCatFeaturesCount = lib.get<CatBoostLibraryAPI::GetCatFeaturesCountFunc>(CatBoostLibraryAPI::GetCatFeaturesCountName);
|
||||
GetTreeCount = lib.tryGet<CatBoostLibraryAPI::GetTreeCountFunc>(CatBoostLibraryAPI::GetTreeCountName);
|
||||
GetDimensionsCount = lib.tryGet<CatBoostLibraryAPI::GetDimensionsCountFunc>(CatBoostLibraryAPI::GetDimensionsCountName);
|
||||
}
|
||||
|
||||
CatBoostLibraryHandler::CatBoostLibraryHandler(
|
||||
const std::string & library_path,
|
||||
const std::string & model_path)
|
||||
: library(std::make_shared<SharedLibrary>(library_path))
|
||||
, api(*library)
|
||||
{
|
||||
model_calcer_handle = api.ModelCalcerCreate();
|
||||
|
||||
if (!api.LoadFullModelFromFile(model_calcer_handle, model_path.c_str()))
|
||||
{
|
||||
throw Exception(ErrorCodes::CANNOT_LOAD_CATBOOST_MODEL,
|
||||
"Cannot load CatBoost model: {}", api.GetErrorString());
|
||||
}
|
||||
|
||||
float_features_count = api.GetFloatFeaturesCount(model_calcer_handle);
|
||||
cat_features_count = api.GetCatFeaturesCount(model_calcer_handle);
|
||||
|
||||
tree_count = 1;
|
||||
if (api.GetDimensionsCount)
|
||||
tree_count = api.GetDimensionsCount(model_calcer_handle);
|
||||
}
|
||||
|
||||
CatBoostLibraryHandler::~CatBoostLibraryHandler()
|
||||
{
|
||||
api.ModelCalcerDelete(model_calcer_handle);
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
/// Buffer should be allocated with features_count * column->size() elements.
|
||||
/// Place column elements in positions buffer[0], buffer[features_count], ... , buffer[size * features_count]
|
||||
template <typename T>
|
||||
void placeColumnAsNumber(const IColumn * column, T * buffer, size_t features_count)
|
||||
{
|
||||
size_t size = column->size();
|
||||
FieldVisitorConvertToNumber<T> visitor;
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
/// TODO: Replace with column visitor.
|
||||
Field field;
|
||||
column->get(i, field);
|
||||
*buffer = applyVisitor(visitor, field);
|
||||
buffer += features_count;
|
||||
}
|
||||
}
|
||||
|
||||
/// Buffer should be allocated with features_count * column->size() elements.
|
||||
/// Place string pointers in positions buffer[0], buffer[features_count], ... , buffer[size * features_count]
|
||||
void placeStringColumn(const ColumnString & column, const char ** buffer, size_t features_count)
|
||||
{
|
||||
size_t size = column.size();
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
*buffer = const_cast<char *>(column.getDataAtWithTerminatingZero(i).data);
|
||||
buffer += features_count;
|
||||
}
|
||||
}
|
||||
|
||||
/// Buffer should be allocated with features_count * column->size() elements.
|
||||
/// Place string pointers in positions buffer[0], buffer[features_count], ... , buffer[size * features_count]
|
||||
/// Returns PODArray which holds data (because ColumnFixedString doesn't store terminating zero).
|
||||
PODArray<char> placeFixedStringColumn(const ColumnFixedString & column, const char ** buffer, size_t features_count)
|
||||
{
|
||||
size_t size = column.size();
|
||||
size_t str_size = column.getN();
|
||||
PODArray<char> data(size * (str_size + 1));
|
||||
char * data_ptr = data.data();
|
||||
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
auto ref = column.getDataAt(i);
|
||||
memcpy(data_ptr, ref.data, ref.size);
|
||||
data_ptr[ref.size] = 0;
|
||||
*buffer = data_ptr;
|
||||
data_ptr += ref.size + 1;
|
||||
buffer += features_count;
|
||||
}
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
/// Place columns into buffer, returns column which holds placed data. Buffer should contains column->size() values.
|
||||
template <typename T>
|
||||
ColumnPtr placeNumericColumns(const ColumnRawPtrs & columns, size_t offset, size_t size, const T** buffer)
|
||||
{
|
||||
if (size == 0)
|
||||
return nullptr;
|
||||
|
||||
size_t column_size = columns[offset]->size();
|
||||
auto data_column = ColumnVector<T>::create(size * column_size);
|
||||
T * data = data_column->getData().data();
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
const auto * column = columns[offset + i];
|
||||
if (column->isNumeric())
|
||||
placeColumnAsNumber(column, data + i, size);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < column_size; ++i)
|
||||
{
|
||||
*buffer = data;
|
||||
++buffer;
|
||||
data += size;
|
||||
}
|
||||
|
||||
return data_column;
|
||||
}
|
||||
|
||||
/// Place columns into buffer, returns data which was used for fixed string columns.
|
||||
/// Buffer should contains column->size() values, each value contains size strings.
|
||||
std::vector<PODArray<char>> placeStringColumns(const ColumnRawPtrs & columns, size_t offset, size_t size, const char ** buffer)
|
||||
{
|
||||
if (size == 0)
|
||||
return {};
|
||||
|
||||
std::vector<PODArray<char>> data;
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
const auto * column = columns[offset + i];
|
||||
if (const auto * column_string = typeid_cast<const ColumnString *>(column))
|
||||
placeStringColumn(*column_string, buffer + i, size);
|
||||
else if (const auto * column_fixed_string = typeid_cast<const ColumnFixedString *>(column))
|
||||
data.push_back(placeFixedStringColumn(*column_fixed_string, buffer + i, size));
|
||||
else
|
||||
throw Exception("Cannot place string column.", ErrorCodes::LOGICAL_ERROR);
|
||||
}
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
/// buffer[column_size * cat_features_count] -> char * => cat_features[column_size][cat_features_count] -> char *
|
||||
void fillCatFeaturesBuffer(
|
||||
const char *** cat_features, const char ** buffer,
|
||||
size_t column_size, size_t cat_features_count)
|
||||
{
|
||||
for (size_t i = 0; i < column_size; ++i)
|
||||
{
|
||||
*cat_features = buffer;
|
||||
++cat_features;
|
||||
buffer += cat_features_count;
|
||||
}
|
||||
}
|
||||
|
||||
/// Calc hash for string cat feature at ps positions.
|
||||
template <typename Column>
|
||||
void calcStringHashes(const Column * column, size_t ps, const int ** buffer, const CatBoostLibraryHandler::APIHolder & api)
|
||||
{
|
||||
size_t column_size = column->size();
|
||||
for (size_t j = 0; j < column_size; ++j)
|
||||
{
|
||||
auto ref = column->getDataAt(j);
|
||||
const_cast<int *>(*buffer)[ps] = api.GetStringCatFeatureHash(ref.data, ref.size);
|
||||
++buffer;
|
||||
}
|
||||
}
|
||||
|
||||
/// Calc hash for int cat feature at ps position. Buffer at positions ps should contains unhashed values.
|
||||
void calcIntHashes(size_t column_size, size_t ps, const int ** buffer, const CatBoostLibraryHandler::APIHolder & api)
|
||||
{
|
||||
for (size_t j = 0; j < column_size; ++j)
|
||||
{
|
||||
const_cast<int *>(*buffer)[ps] = api.GetIntegerCatFeatureHash((*buffer)[ps]);
|
||||
++buffer;
|
||||
}
|
||||
}
|
||||
|
||||
/// buffer contains column->size() rows and size columns.
|
||||
/// For int cat features calc hash inplace.
|
||||
/// For string cat features calc hash from column rows.
|
||||
void calcHashes(const ColumnRawPtrs & columns, size_t offset, size_t size, const int ** buffer, const CatBoostLibraryHandler::APIHolder & api)
|
||||
{
|
||||
if (size == 0)
|
||||
return;
|
||||
size_t column_size = columns[offset]->size();
|
||||
|
||||
std::vector<PODArray<char>> data;
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
const auto * column = columns[offset + i];
|
||||
if (const auto * column_string = typeid_cast<const ColumnString *>(column))
|
||||
calcStringHashes(column_string, i, buffer, api);
|
||||
else if (const auto * column_fixed_string = typeid_cast<const ColumnFixedString *>(column))
|
||||
calcStringHashes(column_fixed_string, i, buffer, api);
|
||||
else
|
||||
calcIntHashes(column_size, i, buffer, api);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/// Convert values to row-oriented format and call evaluation function from CatBoost wrapper api.
|
||||
/// * CalcModelPredictionFlat if no cat features
|
||||
/// * CalcModelPrediction if all cat features are strings
|
||||
/// * CalcModelPredictionWithHashedCatFeatures if has int cat features.
|
||||
ColumnFloat64::MutablePtr CatBoostLibraryHandler::evalImpl(
|
||||
const ColumnRawPtrs & columns,
|
||||
bool cat_features_are_strings) const
|
||||
{
|
||||
std::string error_msg = "Error occurred while applying CatBoost model: ";
|
||||
size_t column_size = columns.front()->size();
|
||||
|
||||
auto result = ColumnFloat64::create(column_size * tree_count);
|
||||
auto * result_buf = result->getData().data();
|
||||
|
||||
if (!column_size)
|
||||
return result;
|
||||
|
||||
/// Prepare float features.
|
||||
PODArray<const float *> float_features(column_size);
|
||||
auto * float_features_buf = float_features.data();
|
||||
/// Store all float data into single column. float_features is a list of pointers to it.
|
||||
auto float_features_col = placeNumericColumns<float>(columns, 0, float_features_count, float_features_buf);
|
||||
|
||||
if (cat_features_count == 0)
|
||||
{
|
||||
if (!api.CalcModelPredictionFlat(model_calcer_handle, column_size,
|
||||
float_features_buf, float_features_count,
|
||||
result_buf, column_size * tree_count))
|
||||
{
|
||||
|
||||
throw Exception(error_msg + api.GetErrorString(), ErrorCodes::CANNOT_APPLY_CATBOOST_MODEL);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Prepare cat features.
|
||||
if (cat_features_are_strings)
|
||||
{
|
||||
/// cat_features_holder stores pointers to ColumnString data or fixed_strings_data.
|
||||
PODArray<const char *> cat_features_holder(cat_features_count * column_size);
|
||||
PODArray<const char **> cat_features(column_size);
|
||||
auto * cat_features_buf = cat_features.data();
|
||||
|
||||
fillCatFeaturesBuffer(cat_features_buf, cat_features_holder.data(), column_size, cat_features_count);
|
||||
/// Fixed strings are stored without termination zero, so have to copy data into fixed_strings_data.
|
||||
auto fixed_strings_data = placeStringColumns(columns, float_features_count,
|
||||
cat_features_count, cat_features_holder.data());
|
||||
|
||||
if (!api.CalcModelPrediction(model_calcer_handle, column_size,
|
||||
float_features_buf, float_features_count,
|
||||
cat_features_buf, cat_features_count,
|
||||
result_buf, column_size * tree_count))
|
||||
{
|
||||
throw Exception(error_msg + api.GetErrorString(), ErrorCodes::CANNOT_APPLY_CATBOOST_MODEL);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
PODArray<const int *> cat_features(column_size);
|
||||
auto * cat_features_buf = cat_features.data();
|
||||
auto cat_features_col = placeNumericColumns<int>(columns, float_features_count,
|
||||
cat_features_count, cat_features_buf);
|
||||
calcHashes(columns, float_features_count, cat_features_count, cat_features_buf, api);
|
||||
if (!api.CalcModelPredictionWithHashedCatFeatures(
|
||||
model_calcer_handle, column_size,
|
||||
float_features_buf, float_features_count,
|
||||
cat_features_buf, cat_features_count,
|
||||
result_buf, column_size * tree_count))
|
||||
{
|
||||
throw Exception(error_msg + api.GetErrorString(), ErrorCodes::CANNOT_APPLY_CATBOOST_MODEL);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
size_t CatBoostLibraryHandler::getTreeCount() const
|
||||
{
|
||||
std::lock_guard lock(mutex);
|
||||
return tree_count;
|
||||
}
|
||||
|
||||
ColumnPtr CatBoostLibraryHandler::evaluate(const ColumnRawPtrs & columns) const
|
||||
{
|
||||
std::lock_guard lock(mutex);
|
||||
|
||||
if (columns.empty())
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Got empty columns list for CatBoost model.");
|
||||
|
||||
if (columns.size() != float_features_count + cat_features_count)
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS,
|
||||
"Number of columns is different with number of features: columns size {} float features size {} + cat features size {}",
|
||||
columns.size(),
|
||||
float_features_count,
|
||||
cat_features_count);
|
||||
|
||||
for (size_t i = 0; i < float_features_count; ++i)
|
||||
{
|
||||
if (!columns[i]->isNumeric())
|
||||
{
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Column {} should be numeric to make float feature.", i);
|
||||
}
|
||||
}
|
||||
|
||||
bool cat_features_are_strings = true;
|
||||
for (size_t i = float_features_count; i < float_features_count + cat_features_count; ++i)
|
||||
{
|
||||
const auto * column = columns[i];
|
||||
if (column->isNumeric())
|
||||
{
|
||||
cat_features_are_strings = false;
|
||||
}
|
||||
else if (!(typeid_cast<const ColumnString *>(column)
|
||||
|| typeid_cast<const ColumnFixedString *>(column)))
|
||||
{
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Column {} should be numeric or string.", i);
|
||||
}
|
||||
}
|
||||
|
||||
auto result = evalImpl(columns, cat_features_are_strings);
|
||||
|
||||
if (tree_count == 1)
|
||||
return result;
|
||||
|
||||
size_t column_size = columns.front()->size();
|
||||
auto * result_buf = result->getData().data();
|
||||
|
||||
/// Multiple trees case. Copy data to several columns.
|
||||
MutableColumns mutable_columns(tree_count);
|
||||
std::vector<Float64 *> column_ptrs(tree_count);
|
||||
for (size_t i = 0; i < tree_count; ++i)
|
||||
{
|
||||
auto col = ColumnFloat64::create(column_size);
|
||||
column_ptrs[i] = col->getData().data();
|
||||
mutable_columns[i] = std::move(col);
|
||||
}
|
||||
|
||||
Float64 * data = result_buf;
|
||||
for (size_t row = 0; row < column_size; ++row)
|
||||
{
|
||||
for (size_t i = 0; i < tree_count; ++i)
|
||||
{
|
||||
*column_ptrs[i] = *data;
|
||||
++column_ptrs[i];
|
||||
++data;
|
||||
}
|
||||
}
|
||||
|
||||
return ColumnTuple::create(std::move(mutable_columns));
|
||||
}
|
||||
|
||||
}
|
71
programs/library-bridge/CatBoostLibraryHandler.h
Normal file
71
programs/library-bridge/CatBoostLibraryHandler.h
Normal file
@ -0,0 +1,71 @@
|
||||
#pragma once
|
||||
|
||||
#include "CatBoostLibraryAPI.h"
|
||||
|
||||
#include <Columns/ColumnFixedString.h>
|
||||
#include <Columns/ColumnString.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
#include <Columns/IColumn.h>
|
||||
#include <Common/SharedLibrary.h>
|
||||
#include <base/defines.h>
|
||||
|
||||
#include <mutex>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
/// Abstracts access to the CatBoost shared library.
|
||||
class CatBoostLibraryHandler
|
||||
{
|
||||
public:
|
||||
/// Holds pointers to CatBoost library functions
|
||||
struct APIHolder
|
||||
{
|
||||
explicit APIHolder(SharedLibrary & lib);
|
||||
|
||||
// NOLINTBEGIN(readability-identifier-naming)
|
||||
CatBoostLibraryAPI::ModelCalcerCreateFunc ModelCalcerCreate;
|
||||
CatBoostLibraryAPI::ModelCalcerDeleteFunc ModelCalcerDelete;
|
||||
CatBoostLibraryAPI::GetErrorStringFunc GetErrorString;
|
||||
CatBoostLibraryAPI::LoadFullModelFromFileFunc LoadFullModelFromFile;
|
||||
CatBoostLibraryAPI::CalcModelPredictionFlatFunc CalcModelPredictionFlat;
|
||||
CatBoostLibraryAPI::CalcModelPredictionFunc CalcModelPrediction;
|
||||
CatBoostLibraryAPI::CalcModelPredictionWithHashedCatFeaturesFunc CalcModelPredictionWithHashedCatFeatures;
|
||||
CatBoostLibraryAPI::GetStringCatFeatureHashFunc GetStringCatFeatureHash;
|
||||
CatBoostLibraryAPI::GetIntegerCatFeatureHashFunc GetIntegerCatFeatureHash;
|
||||
CatBoostLibraryAPI::GetFloatFeaturesCountFunc GetFloatFeaturesCount;
|
||||
CatBoostLibraryAPI::GetCatFeaturesCountFunc GetCatFeaturesCount;
|
||||
CatBoostLibraryAPI::GetTreeCountFunc GetTreeCount;
|
||||
CatBoostLibraryAPI::GetDimensionsCountFunc GetDimensionsCount;
|
||||
// NOLINTEND(readability-identifier-naming)
|
||||
};
|
||||
|
||||
CatBoostLibraryHandler(
|
||||
const std::string & library_path,
|
||||
const std::string & model_path);
|
||||
|
||||
~CatBoostLibraryHandler();
|
||||
|
||||
size_t getTreeCount() const;
|
||||
|
||||
ColumnPtr evaluate(const ColumnRawPtrs & columns) const;
|
||||
|
||||
private:
|
||||
const SharedLibraryPtr library;
|
||||
const APIHolder api;
|
||||
|
||||
mutable std::mutex mutex;
|
||||
|
||||
CatBoostLibraryAPI::ModelCalcerHandle * model_calcer_handle TSA_GUARDED_BY(mutex) TSA_PT_GUARDED_BY(mutex);
|
||||
|
||||
size_t float_features_count TSA_GUARDED_BY(mutex);
|
||||
size_t cat_features_count TSA_GUARDED_BY(mutex);
|
||||
size_t tree_count TSA_GUARDED_BY(mutex);
|
||||
|
||||
ColumnFloat64::MutablePtr evalImpl(const ColumnRawPtrs & columns, bool cat_features_are_strings) const TSA_REQUIRES(mutex);
|
||||
};
|
||||
|
||||
using CatBoostLibraryHandlerPtr = std::shared_ptr<CatBoostLibraryHandler>;
|
||||
|
||||
}
|
49
programs/library-bridge/CatBoostLibraryHandlerFactory.cpp
Normal file
49
programs/library-bridge/CatBoostLibraryHandlerFactory.cpp
Normal file
@ -0,0 +1,49 @@
|
||||
#include "CatBoostLibraryHandlerFactory.h"
|
||||
|
||||
#include <Common/logger_useful.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
CatBoostLibraryHandlerFactory & CatBoostLibraryHandlerFactory::instance()
|
||||
{
|
||||
static CatBoostLibraryHandlerFactory instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
CatBoostLibraryHandlerPtr CatBoostLibraryHandlerFactory::get(const String & model_path)
|
||||
{
|
||||
std::lock_guard lock(mutex);
|
||||
|
||||
if (auto handler = library_handlers.find(model_path); handler != library_handlers.end())
|
||||
return handler->second;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void CatBoostLibraryHandlerFactory::create(const String & library_path, const String & model_path)
|
||||
{
|
||||
std::lock_guard lock(mutex);
|
||||
|
||||
if (library_handlers.contains(model_path))
|
||||
{
|
||||
LOG_DEBUG(&Poco::Logger::get("CatBoostLibraryHandlerFactory"), "Cannot load catboost library handler for model path {} because it exists already", model_path);
|
||||
return;
|
||||
}
|
||||
|
||||
library_handlers.emplace(std::make_pair(model_path, std::make_shared<CatBoostLibraryHandler>(library_path, model_path)));
|
||||
LOG_DEBUG(&Poco::Logger::get("CatBoostLibraryHandlerFactory"), "Loaded catboost library handler for model path {}.", model_path);
|
||||
}
|
||||
|
||||
void CatBoostLibraryHandlerFactory::remove(const String & model_path)
|
||||
{
|
||||
std::lock_guard lock(mutex);
|
||||
bool deleted = library_handlers.erase(model_path);
|
||||
if (!deleted)
|
||||
{
|
||||
LOG_DEBUG(&Poco::Logger::get("CatBoostLibraryHandlerFactory"), "Cannot unload catboost library handler for model path: {}", model_path);
|
||||
return;
|
||||
}
|
||||
LOG_DEBUG(&Poco::Logger::get("CatBoostLibraryHandlerFactory"), "Unloaded catboost library handler for model path: {}", model_path);
|
||||
}
|
||||
|
||||
}
|
31
programs/library-bridge/CatBoostLibraryHandlerFactory.h
Normal file
31
programs/library-bridge/CatBoostLibraryHandlerFactory.h
Normal file
@ -0,0 +1,31 @@
|
||||
#pragma once
|
||||
|
||||
#include "CatBoostLibraryHandler.h"
|
||||
|
||||
#include <base/defines.h>
|
||||
|
||||
#include <mutex>
|
||||
#include <unordered_map>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
class CatBoostLibraryHandlerFactory final : private boost::noncopyable
|
||||
{
|
||||
public:
|
||||
static CatBoostLibraryHandlerFactory & instance();
|
||||
|
||||
CatBoostLibraryHandlerPtr get(const String & model_path);
|
||||
|
||||
void create(const String & library_path, const String & model_path);
|
||||
|
||||
void remove(const String & model_path);
|
||||
|
||||
private:
|
||||
/// map: model path -> shared library handler
|
||||
std::unordered_map<String, CatBoostLibraryHandlerPtr> library_handlers TSA_GUARDED_BY(mutex);
|
||||
std::mutex mutex;
|
||||
};
|
||||
|
||||
}
|
@ -50,6 +50,6 @@ private:
|
||||
void * lib_data;
|
||||
};
|
||||
|
||||
using SharedLibraryHandlerPtr = std::shared_ptr<ExternalDictionaryLibraryHandler>;
|
||||
using ExternalDictionaryLibraryHandlerPtr = std::shared_ptr<ExternalDictionaryLibraryHandler>;
|
||||
|
||||
}
|
||||
|
@ -1,37 +1,40 @@
|
||||
#include "ExternalDictionaryLibraryHandlerFactory.h"
|
||||
|
||||
#include <Common/logger_useful.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
SharedLibraryHandlerPtr ExternalDictionaryLibraryHandlerFactory::get(const std::string & dictionary_id)
|
||||
ExternalDictionaryLibraryHandlerPtr ExternalDictionaryLibraryHandlerFactory::get(const String & dictionary_id)
|
||||
{
|
||||
std::lock_guard lock(mutex);
|
||||
auto library_handler = library_handlers.find(dictionary_id);
|
||||
|
||||
if (library_handler != library_handlers.end())
|
||||
return library_handler->second;
|
||||
|
||||
if (auto handler = library_handlers.find(dictionary_id); handler != library_handlers.end())
|
||||
return handler->second;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
||||
void ExternalDictionaryLibraryHandlerFactory::create(
|
||||
const std::string & dictionary_id,
|
||||
const std::string & library_path,
|
||||
const std::vector<std::string> & library_settings,
|
||||
const String & dictionary_id,
|
||||
const String & library_path,
|
||||
const std::vector<String> & library_settings,
|
||||
const Block & sample_block,
|
||||
const std::vector<std::string> & attributes_names)
|
||||
const std::vector<String> & attributes_names)
|
||||
{
|
||||
std::lock_guard lock(mutex);
|
||||
if (!library_handlers.contains(dictionary_id))
|
||||
library_handlers.emplace(std::make_pair(dictionary_id, std::make_shared<ExternalDictionaryLibraryHandler>(library_path, library_settings, sample_block, attributes_names)));
|
||||
else
|
||||
|
||||
if (library_handlers.contains(dictionary_id))
|
||||
{
|
||||
LOG_WARNING(&Poco::Logger::get("ExternalDictionaryLibraryHandlerFactory"), "Library handler with dictionary id {} already exists", dictionary_id);
|
||||
return;
|
||||
}
|
||||
|
||||
library_handlers.emplace(std::make_pair(dictionary_id, std::make_shared<ExternalDictionaryLibraryHandler>(library_path, library_settings, sample_block, attributes_names)));
|
||||
}
|
||||
|
||||
|
||||
bool ExternalDictionaryLibraryHandlerFactory::clone(const std::string & from_dictionary_id, const std::string & to_dictionary_id)
|
||||
bool ExternalDictionaryLibraryHandlerFactory::clone(const String & from_dictionary_id, const String & to_dictionary_id)
|
||||
{
|
||||
std::lock_guard lock(mutex);
|
||||
auto from_library_handler = library_handlers.find(from_dictionary_id);
|
||||
@ -45,7 +48,7 @@ bool ExternalDictionaryLibraryHandlerFactory::clone(const std::string & from_dic
|
||||
}
|
||||
|
||||
|
||||
bool ExternalDictionaryLibraryHandlerFactory::remove(const std::string & dictionary_id)
|
||||
bool ExternalDictionaryLibraryHandlerFactory::remove(const String & dictionary_id)
|
||||
{
|
||||
std::lock_guard lock(mutex);
|
||||
/// extDict_libDelete is called in destructor.
|
||||
|
@ -17,22 +17,22 @@ class ExternalDictionaryLibraryHandlerFactory final : private boost::noncopyable
|
||||
public:
|
||||
static ExternalDictionaryLibraryHandlerFactory & instance();
|
||||
|
||||
SharedLibraryHandlerPtr get(const std::string & dictionary_id);
|
||||
ExternalDictionaryLibraryHandlerPtr get(const String & dictionary_id);
|
||||
|
||||
void create(
|
||||
const std::string & dictionary_id,
|
||||
const std::string & library_path,
|
||||
const std::vector<std::string> & library_settings,
|
||||
const String & dictionary_id,
|
||||
const String & library_path,
|
||||
const std::vector<String> & library_settings,
|
||||
const Block & sample_block,
|
||||
const std::vector<std::string> & attributes_names);
|
||||
const std::vector<String> & attributes_names);
|
||||
|
||||
bool clone(const std::string & from_dictionary_id, const std::string & to_dictionary_id);
|
||||
bool clone(const String & from_dictionary_id, const String & to_dictionary_id);
|
||||
|
||||
bool remove(const std::string & dictionary_id);
|
||||
bool remove(const String & dictionary_id);
|
||||
|
||||
private:
|
||||
/// map: dict_id -> sharedLibraryHandler
|
||||
std::unordered_map<std::string, SharedLibraryHandlerPtr> library_handlers TSA_GUARDED_BY(mutex);
|
||||
std::unordered_map<String, ExternalDictionaryLibraryHandlerPtr> library_handlers TSA_GUARDED_BY(mutex);
|
||||
std::mutex mutex;
|
||||
};
|
||||
|
||||
|
@ -27,12 +27,16 @@ std::unique_ptr<HTTPRequestHandler> LibraryBridgeHandlerFactory::createRequestHa
|
||||
{
|
||||
if (uri.getPath() == "/extdict_ping")
|
||||
return std::make_unique<ExternalDictionaryLibraryBridgeExistsHandler>(keep_alive_timeout, getContext());
|
||||
else if (uri.getPath() == "/catboost_ping")
|
||||
return std::make_unique<CatBoostLibraryBridgeExistsHandler>(keep_alive_timeout, getContext());
|
||||
}
|
||||
|
||||
if (request.getMethod() == Poco::Net::HTTPRequest::HTTP_POST)
|
||||
{
|
||||
if (uri.getPath() == "/extdict_request")
|
||||
return std::make_unique<ExternalDictionaryLibraryBridgeRequestHandler>(keep_alive_timeout, getContext());
|
||||
else if (uri.getPath() == "/catboost_request")
|
||||
return std::make_unique<CatBoostLibraryBridgeRequestHandler>(keep_alive_timeout, getContext());
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
|
@ -1,24 +1,31 @@
|
||||
#include "LibraryBridgeHandlers.h"
|
||||
|
||||
#include "CatBoostLibraryHandler.h"
|
||||
#include "CatBoostLibraryHandlerFactory.h"
|
||||
#include "ExternalDictionaryLibraryHandler.h"
|
||||
#include "ExternalDictionaryLibraryHandlerFactory.h"
|
||||
|
||||
#include <Formats/FormatFactory.h>
|
||||
#include <Server/HTTP/WriteBufferFromHTTPServerResponse.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <IO/ReadBufferFromString.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <Common/BridgeProtocolVersion.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <Poco/Net/HTMLForm.h>
|
||||
#include <Poco/Net/HTTPServerRequest.h>
|
||||
#include <Poco/Net/HTTPServerResponse.h>
|
||||
#include <Poco/Net/HTMLForm.h>
|
||||
#include <Poco/ThreadPool.h>
|
||||
#include <Processors/Formats/IOutputFormat.h>
|
||||
#include <Processors/Formats/IInputFormat.h>
|
||||
#include <QueryPipeline/QueryPipeline.h>
|
||||
#include <Processors/Executors/CompletedPipelineExecutor.h>
|
||||
#include <Processors/Executors/PullingPipelineExecutor.h>
|
||||
#include <Processors/Formats/IInputFormat.h>
|
||||
#include <Processors/Formats/IOutputFormat.h>
|
||||
#include <Processors/Sources/SourceFromSingleChunk.h>
|
||||
#include <QueryPipeline/Pipe.h>
|
||||
#include <QueryPipeline/QueryPipeline.h>
|
||||
#include <Server/HTTP/HTMLForm.h>
|
||||
#include <IO/ReadBufferFromString.h>
|
||||
#include <Server/HTTP/WriteBufferFromHTTPServerResponse.h>
|
||||
#include <Formats/NativeReader.h>
|
||||
#include <Formats/NativeWriter.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
@ -31,7 +38,7 @@ namespace ErrorCodes
|
||||
|
||||
namespace
|
||||
{
|
||||
void processError(HTTPServerResponse & response, const std::string & message)
|
||||
void processError(HTTPServerResponse & response, const String & message)
|
||||
{
|
||||
response.setStatusAndReason(HTTPResponse::HTTP_INTERNAL_SERVER_ERROR);
|
||||
|
||||
@ -41,7 +48,7 @@ namespace
|
||||
LOG_WARNING(&Poco::Logger::get("LibraryBridge"), fmt::runtime(message));
|
||||
}
|
||||
|
||||
std::shared_ptr<Block> parseColumns(std::string && column_string)
|
||||
std::shared_ptr<Block> parseColumns(String && column_string)
|
||||
{
|
||||
auto sample_block = std::make_shared<Block>();
|
||||
auto names_and_types = NamesAndTypesList::parse(column_string);
|
||||
@ -59,10 +66,10 @@ namespace
|
||||
return ids;
|
||||
}
|
||||
|
||||
std::vector<std::string> parseNamesFromBinary(const std::string & names_string)
|
||||
std::vector<String> parseNamesFromBinary(const String & names_string)
|
||||
{
|
||||
ReadBufferFromString buf(names_string);
|
||||
std::vector<std::string> names;
|
||||
std::vector<String> names;
|
||||
readVectorBinary(names, buf);
|
||||
return names;
|
||||
}
|
||||
@ -79,13 +86,15 @@ static void writeData(Block data, OutputFormatPtr format)
|
||||
executor.execute();
|
||||
}
|
||||
|
||||
|
||||
ExternalDictionaryLibraryBridgeRequestHandler::ExternalDictionaryLibraryBridgeRequestHandler(size_t keep_alive_timeout_, ContextPtr context_)
|
||||
: WithContext(context_)
|
||||
, log(&Poco::Logger::get("ExternalDictionaryLibraryBridgeRequestHandler"))
|
||||
, keep_alive_timeout(keep_alive_timeout_)
|
||||
, log(&Poco::Logger::get("ExternalDictionaryLibraryBridgeRequestHandler"))
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response)
|
||||
{
|
||||
LOG_TRACE(log, "Request URI: {}", request.getURI());
|
||||
@ -97,7 +106,7 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
|
||||
version = 0; /// assumed version for too old servers which do not send a version
|
||||
else
|
||||
{
|
||||
String version_str = params.get("version");
|
||||
const String & version_str = params.get("version");
|
||||
if (!tryParse(version, version_str))
|
||||
{
|
||||
processError(response, "Unable to parse 'version' string in request URL: '" + version_str + "' Check if the server and library-bridge have the same version.");
|
||||
@ -124,8 +133,8 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
|
||||
return;
|
||||
}
|
||||
|
||||
std::string method = params.get("method");
|
||||
std::string dictionary_id = params.get("dictionary_id");
|
||||
const String & method = params.get("method");
|
||||
const String & dictionary_id = params.get("dictionary_id");
|
||||
|
||||
LOG_TRACE(log, "Library method: '{}', dictionary id: {}", method, dictionary_id);
|
||||
WriteBufferFromHTTPServerResponse out(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD, keep_alive_timeout);
|
||||
@ -141,7 +150,7 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
|
||||
return;
|
||||
}
|
||||
|
||||
std::string from_dictionary_id = params.get("from_dictionary_id");
|
||||
const String & from_dictionary_id = params.get("from_dictionary_id");
|
||||
bool cloned = false;
|
||||
cloned = ExternalDictionaryLibraryHandlerFactory::instance().clone(from_dictionary_id, dictionary_id);
|
||||
|
||||
@ -166,7 +175,7 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
|
||||
return;
|
||||
}
|
||||
|
||||
std::string library_path = params.get("library_path");
|
||||
const String & library_path = params.get("library_path");
|
||||
|
||||
if (!params.has("library_settings"))
|
||||
{
|
||||
@ -174,10 +183,10 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
|
||||
return;
|
||||
}
|
||||
|
||||
const auto & settings_string = params.get("library_settings");
|
||||
const String & settings_string = params.get("library_settings");
|
||||
|
||||
LOG_DEBUG(log, "Parsing library settings from binary string");
|
||||
std::vector<std::string> library_settings = parseNamesFromBinary(settings_string);
|
||||
std::vector<String> library_settings = parseNamesFromBinary(settings_string);
|
||||
|
||||
/// Needed for library dictionary
|
||||
if (!params.has("attributes_names"))
|
||||
@ -186,10 +195,10 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
|
||||
return;
|
||||
}
|
||||
|
||||
const auto & attributes_string = params.get("attributes_names");
|
||||
const String & attributes_string = params.get("attributes_names");
|
||||
|
||||
LOG_DEBUG(log, "Parsing attributes names from binary string");
|
||||
std::vector<std::string> attributes_names = parseNamesFromBinary(attributes_string);
|
||||
std::vector<String> attributes_names = parseNamesFromBinary(attributes_string);
|
||||
|
||||
/// Needed to parse block from binary string format
|
||||
if (!params.has("sample_block"))
|
||||
@ -197,7 +206,7 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
|
||||
processError(response, "No 'sample_block' in request URL");
|
||||
return;
|
||||
}
|
||||
std::string sample_block_string = params.get("sample_block");
|
||||
String sample_block_string = params.get("sample_block");
|
||||
|
||||
std::shared_ptr<Block> sample_block;
|
||||
try
|
||||
@ -297,7 +306,7 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
|
||||
return;
|
||||
}
|
||||
|
||||
std::string requested_block_string = params.get("requested_block_sample");
|
||||
String requested_block_string = params.get("requested_block_sample");
|
||||
|
||||
std::shared_ptr<Block> requested_sample_block;
|
||||
try
|
||||
@ -332,7 +341,8 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
|
||||
}
|
||||
else
|
||||
{
|
||||
LOG_WARNING(log, "Unknown library method: '{}'", method);
|
||||
processError(response, "Unknown library method '" + method + "'");
|
||||
LOG_ERROR(log, "Unknown library method: '{}'", method);
|
||||
}
|
||||
}
|
||||
catch (...)
|
||||
@ -362,6 +372,7 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
ExternalDictionaryLibraryBridgeExistsHandler::ExternalDictionaryLibraryBridgeExistsHandler(size_t keep_alive_timeout_, ContextPtr context_)
|
||||
: WithContext(context_)
|
||||
, keep_alive_timeout(keep_alive_timeout_)
|
||||
@ -369,6 +380,7 @@ ExternalDictionaryLibraryBridgeExistsHandler::ExternalDictionaryLibraryBridgeExi
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
void ExternalDictionaryLibraryBridgeExistsHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response)
|
||||
{
|
||||
try
|
||||
@ -382,7 +394,7 @@ void ExternalDictionaryLibraryBridgeExistsHandler::handleRequest(HTTPServerReque
|
||||
return;
|
||||
}
|
||||
|
||||
std::string dictionary_id = params.get("dictionary_id");
|
||||
const String & dictionary_id = params.get("dictionary_id");
|
||||
|
||||
auto library_handler = ExternalDictionaryLibraryHandlerFactory::instance().get(dictionary_id);
|
||||
|
||||
@ -399,4 +411,199 @@ void ExternalDictionaryLibraryBridgeExistsHandler::handleRequest(HTTPServerReque
|
||||
}
|
||||
|
||||
|
||||
CatBoostLibraryBridgeRequestHandler::CatBoostLibraryBridgeRequestHandler(
|
||||
size_t keep_alive_timeout_, ContextPtr context_)
|
||||
: WithContext(context_)
|
||||
, keep_alive_timeout(keep_alive_timeout_)
|
||||
, log(&Poco::Logger::get("CatBoostLibraryBridgeRequestHandler"))
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
void CatBoostLibraryBridgeRequestHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response)
|
||||
{
|
||||
std::lock_guard lock(mutex);
|
||||
|
||||
LOG_TRACE(log, "Request URI: {}", request.getURI());
|
||||
HTMLForm params(getContext()->getSettingsRef(), request);
|
||||
|
||||
size_t version;
|
||||
|
||||
if (!params.has("version"))
|
||||
version = 0; /// assumed version for too old servers which do not send a version
|
||||
else
|
||||
{
|
||||
const String & version_str = params.get("version");
|
||||
if (!tryParse(version, version_str))
|
||||
{
|
||||
processError(response, "Unable to parse 'version' string in request URL: '" + version_str + "' Check if the server and library-bridge have the same version.");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (version != LIBRARY_BRIDGE_PROTOCOL_VERSION)
|
||||
{
|
||||
/// backwards compatibility is considered unnecessary for now, just let the user know that the server and the bridge must be upgraded together
|
||||
processError(response, "Server and library-bridge have different versions: '" + std::to_string(version) + "' vs. '" + std::to_string(LIBRARY_BRIDGE_PROTOCOL_VERSION) + "'");
|
||||
return;
|
||||
}
|
||||
if (!params.has("method"))
|
||||
{
|
||||
processError(response, "No 'method' in request URL");
|
||||
return;
|
||||
}
|
||||
|
||||
const String & method = params.get("method");
|
||||
|
||||
LOG_TRACE(log, "Library method: '{}'", method);
|
||||
WriteBufferFromHTTPServerResponse out(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD, keep_alive_timeout);
|
||||
|
||||
try
|
||||
{
|
||||
if (method == "catboost_GetTreeCount")
|
||||
{
|
||||
auto & read_buf = request.getStream();
|
||||
params.read(read_buf);
|
||||
|
||||
if (!params.has("library_path"))
|
||||
{
|
||||
processError(response, "No 'library_path' in request URL");
|
||||
return;
|
||||
}
|
||||
|
||||
const String & library_path = params.get("library_path");
|
||||
|
||||
if (!params.has("model_path"))
|
||||
{
|
||||
processError(response, "No 'model_path' in request URL");
|
||||
return;
|
||||
}
|
||||
|
||||
const String & model_path = params.get("model_path");
|
||||
|
||||
CatBoostLibraryHandlerFactory::instance().remove(model_path);
|
||||
|
||||
CatBoostLibraryHandlerFactory::instance().create(library_path, model_path);
|
||||
|
||||
auto catboost_handler = CatBoostLibraryHandlerFactory::instance().get(model_path);
|
||||
|
||||
if (!catboost_handler)
|
||||
{
|
||||
processError(response, "CatBoost library is not loaded for model " + model_path);
|
||||
return;
|
||||
}
|
||||
|
||||
size_t tree_count = catboost_handler->getTreeCount();
|
||||
writeIntBinary(tree_count, out);
|
||||
}
|
||||
else if (method == "catboost_libEvaluate")
|
||||
{
|
||||
auto & read_buf = request.getStream();
|
||||
params.read(read_buf);
|
||||
|
||||
if (!params.has("model_path"))
|
||||
{
|
||||
processError(response, "No 'model_path' in request URL");
|
||||
return;
|
||||
}
|
||||
|
||||
const String & model_path = params.get("model_path");
|
||||
|
||||
if (!params.has("data"))
|
||||
{
|
||||
processError(response, "No 'data' in request URL");
|
||||
return;
|
||||
}
|
||||
|
||||
const String & data = params.get("data");
|
||||
|
||||
ReadBufferFromString string_read_buf(data);
|
||||
NativeReader deserializer(string_read_buf, /*server_revision*/ 0);
|
||||
Block block_read = deserializer.read();
|
||||
|
||||
Columns col_ptrs = block_read.getColumns();
|
||||
ColumnRawPtrs col_raw_ptrs;
|
||||
for (const auto & p : col_ptrs)
|
||||
col_raw_ptrs.push_back(&*p);
|
||||
|
||||
auto catboost_handler = CatBoostLibraryHandlerFactory::instance().get(model_path);
|
||||
|
||||
if (!catboost_handler)
|
||||
{
|
||||
processError(response, "CatBoost library is not loaded for model" + model_path);
|
||||
return;
|
||||
}
|
||||
|
||||
ColumnPtr res_col = catboost_handler->evaluate(col_raw_ptrs);
|
||||
|
||||
DataTypePtr res_col_type = std::make_shared<DataTypeFloat64>();
|
||||
String res_col_name = "res_col";
|
||||
|
||||
ColumnsWithTypeAndName res_cols_with_type_and_name = {{res_col, res_col_type, res_col_name}};
|
||||
|
||||
Block block_write(res_cols_with_type_and_name);
|
||||
NativeWriter serializer{out, /*client_revision*/ 0, block_write};
|
||||
serializer.write(block_write);
|
||||
}
|
||||
else
|
||||
{
|
||||
processError(response, "Unknown library method '" + method + "'");
|
||||
LOG_ERROR(log, "Unknown library method: '{}'", method);
|
||||
}
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
auto message = getCurrentExceptionMessage(true);
|
||||
LOG_ERROR(log, "Failed to process request. Error: {}", message);
|
||||
|
||||
response.setStatusAndReason(Poco::Net::HTTPResponse::HTTP_INTERNAL_SERVER_ERROR, message); // can't call process_error, because of too soon response sending
|
||||
try
|
||||
{
|
||||
writeStringBinary(message, out);
|
||||
out.finalize();
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
tryLogCurrentException(log);
|
||||
}
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
out.finalize();
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
tryLogCurrentException(log);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
CatBoostLibraryBridgeExistsHandler::CatBoostLibraryBridgeExistsHandler(size_t keep_alive_timeout_, ContextPtr context_)
|
||||
: WithContext(context_)
|
||||
, keep_alive_timeout(keep_alive_timeout_)
|
||||
, log(&Poco::Logger::get("CatBoostLibraryBridgeExistsHandler"))
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
void CatBoostLibraryBridgeExistsHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response)
|
||||
{
|
||||
try
|
||||
{
|
||||
LOG_TRACE(log, "Request URI: {}", request.getURI());
|
||||
HTMLForm params(getContext()->getSettingsRef(), request);
|
||||
|
||||
String res = "1";
|
||||
|
||||
setResponseDefaultHeaders(response, keep_alive_timeout);
|
||||
LOG_TRACE(log, "Sending ping response: {}", res);
|
||||
response.sendBuffer(res.data(), res.size());
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
tryLogCurrentException("PingHandler");
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -1,9 +1,9 @@
|
||||
#pragma once
|
||||
|
||||
#include <Common/logger_useful.h>
|
||||
#include <Interpreters/Context.h>
|
||||
#include <Server/HTTP/HTTPRequestHandler.h>
|
||||
#include <Common/logger_useful.h>
|
||||
#include "ExternalDictionaryLibraryHandler.h"
|
||||
#include <mutex>
|
||||
|
||||
|
||||
namespace DB
|
||||
@ -26,11 +26,12 @@ public:
|
||||
private:
|
||||
static constexpr inline auto FORMAT = "RowBinary";
|
||||
|
||||
const size_t keep_alive_timeout;
|
||||
Poco::Logger * log;
|
||||
size_t keep_alive_timeout;
|
||||
};
|
||||
|
||||
|
||||
// Handler for checking if the external dictionary library is loaded (used for handshake)
|
||||
class ExternalDictionaryLibraryBridgeExistsHandler : public HTTPRequestHandler, WithContext
|
||||
{
|
||||
public:
|
||||
@ -43,4 +44,43 @@ private:
|
||||
Poco::Logger * log;
|
||||
};
|
||||
|
||||
|
||||
/// Handler for requests to catboost library. The call protocol is as follows:
|
||||
/// (1) Send a "catboost_GetTreeCount" request from the server to the bridge, containing a library path (e.g /home/user/libcatboost.so) and
|
||||
/// a model path (e.g. /home/user/model.bin). Rirst, this unloads the catboost library handler associated to the model path (if it was
|
||||
/// loaded), then loads the catboost library handler associated to the model path, then executes GetTreeCount() on the library handler
|
||||
/// and finally sends the result back to the server.
|
||||
/// Step (1) is called once by the server from FunctionCatBoostEvaluate::getReturnTypeImpl(). The library path handler is unloaded in
|
||||
/// the beginning because it contains state which may no longer be valid if the user runs catboost("/path/to/model.bin", ...) more than
|
||||
/// once and if "model.bin" was updated in between.
|
||||
/// (2) Send "catboost_Evaluate" from the server to the bridge, containing the model path and the features to run the interference on.
|
||||
/// Step (2) is called multiple times (once per chunk) by the server from function FunctionCatBoostEvaluate::executeImpl(). The library
|
||||
/// handler for the given model path is expected to be already loaded by Step (1).
|
||||
class CatBoostLibraryBridgeRequestHandler : public HTTPRequestHandler, WithContext
|
||||
{
|
||||
public:
|
||||
CatBoostLibraryBridgeRequestHandler(size_t keep_alive_timeout_, ContextPtr context_);
|
||||
|
||||
void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) override;
|
||||
|
||||
private:
|
||||
std::mutex mutex;
|
||||
const size_t keep_alive_timeout;
|
||||
Poco::Logger * log;
|
||||
};
|
||||
|
||||
|
||||
// Handler for pinging the library-bridge for catboost access (used for handshake)
|
||||
class CatBoostLibraryBridgeExistsHandler : public HTTPRequestHandler, WithContext
|
||||
{
|
||||
public:
|
||||
CatBoostLibraryBridgeExistsHandler(size_t keep_alive_timeout_, ContextPtr context_);
|
||||
|
||||
void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) override;
|
||||
|
||||
private:
|
||||
const size_t keep_alive_timeout;
|
||||
Poco::Logger * log;
|
||||
};
|
||||
|
||||
}
|
||||
|
118
src/BridgeHelper/CatBoostLibraryBridgeHelper.cpp
Normal file
118
src/BridgeHelper/CatBoostLibraryBridgeHelper.cpp
Normal file
@ -0,0 +1,118 @@
|
||||
#include "CatBoostLibraryBridgeHelper.h"
|
||||
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
#include <Common/escapeForFileName.h>
|
||||
#include <Core/Block.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <Formats/NativeReader.h>
|
||||
#include <Formats/NativeWriter.h>
|
||||
#include <IO/ReadBufferFromString.h>
|
||||
#include <IO/WriteBufferFromString.h>
|
||||
#include <Poco/Net/HTTPRequest.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int LOGICAL_ERROR;
|
||||
}
|
||||
|
||||
CatBoostLibraryBridgeHelper::CatBoostLibraryBridgeHelper(
|
||||
ContextPtr context_,
|
||||
std::string_view library_path_,
|
||||
std::string_view model_path_)
|
||||
: LibraryBridgeHelper(context_->getGlobalContext())
|
||||
, library_path(library_path_)
|
||||
, model_path(model_path_)
|
||||
{
|
||||
}
|
||||
|
||||
Poco::URI CatBoostLibraryBridgeHelper::getPingURI() const
|
||||
{
|
||||
auto uri = createBaseURI();
|
||||
uri.setPath(PING_HANDLER);
|
||||
return uri;
|
||||
}
|
||||
|
||||
Poco::URI CatBoostLibraryBridgeHelper::getMainURI() const
|
||||
{
|
||||
auto uri = createBaseURI();
|
||||
uri.setPath(MAIN_HANDLER);
|
||||
return uri;
|
||||
}
|
||||
|
||||
|
||||
Poco::URI CatBoostLibraryBridgeHelper::createRequestURI(const String & method) const
|
||||
{
|
||||
auto uri = getMainURI();
|
||||
uri.addQueryParameter("version", std::to_string(LIBRARY_BRIDGE_PROTOCOL_VERSION));
|
||||
uri.addQueryParameter("method", method);
|
||||
return uri;
|
||||
}
|
||||
|
||||
bool CatBoostLibraryBridgeHelper::bridgeHandShake()
|
||||
{
|
||||
String result;
|
||||
try
|
||||
{
|
||||
ReadWriteBufferFromHTTP buf(getPingURI(), Poco::Net::HTTPRequest::HTTP_GET, {}, http_timeouts, credentials);
|
||||
readString(result, buf);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
tryLogCurrentException(log);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (result != "1")
|
||||
throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected message from library bridge: {}. Check that bridge and server have the same version.", result);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
size_t CatBoostLibraryBridgeHelper::getTreeCount()
|
||||
{
|
||||
startBridgeSync();
|
||||
|
||||
ReadWriteBufferFromHTTP buf(
|
||||
createRequestURI(CATBOOST_GETTREECOUNT_METHOD),
|
||||
Poco::Net::HTTPRequest::HTTP_POST,
|
||||
[this](std::ostream & os)
|
||||
{
|
||||
os << "library_path=" << escapeForFileName(library_path) << "&";
|
||||
os << "model_path=" << escapeForFileName(model_path);
|
||||
},
|
||||
http_timeouts, credentials);
|
||||
|
||||
size_t res;
|
||||
readIntBinary(res, buf);
|
||||
return res;
|
||||
}
|
||||
|
||||
ColumnPtr CatBoostLibraryBridgeHelper::evaluate(const ColumnsWithTypeAndName & columns)
|
||||
{
|
||||
startBridgeSync();
|
||||
|
||||
WriteBufferFromOwnString string_write_buf;
|
||||
Block block(columns);
|
||||
NativeWriter serializer(string_write_buf, /*client_revision*/ 0, block);
|
||||
serializer.write(block);
|
||||
|
||||
ReadWriteBufferFromHTTP buf(
|
||||
createRequestURI(CATBOOST_LIB_EVALUATE_METHOD),
|
||||
Poco::Net::HTTPRequest::HTTP_POST,
|
||||
[this, serialized = string_write_buf.str()](std::ostream & os)
|
||||
{
|
||||
os << "model_path=" << escapeForFileName(model_path) << "&";
|
||||
os << "data=" << escapeForFileName(serialized);
|
||||
},
|
||||
http_timeouts, credentials);
|
||||
|
||||
NativeReader deserializer(buf, /*server_revision*/ 0);
|
||||
Block block_read = deserializer.read();
|
||||
|
||||
return block_read.getColumns()[0];
|
||||
}
|
||||
|
||||
}
|
42
src/BridgeHelper/CatBoostLibraryBridgeHelper.h
Normal file
42
src/BridgeHelper/CatBoostLibraryBridgeHelper.h
Normal file
@ -0,0 +1,42 @@
|
||||
#pragma once
|
||||
|
||||
#include <BridgeHelper/LibraryBridgeHelper.h>
|
||||
#include <DataTypes/IDataType.h>
|
||||
#include <IO/ReadWriteBufferFromHTTP.h>
|
||||
#include <Interpreters/Context.h>
|
||||
#include <Poco/URI.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
class CatBoostLibraryBridgeHelper : public LibraryBridgeHelper
|
||||
{
|
||||
public:
|
||||
static constexpr inline auto PING_HANDLER = "/catboost_ping";
|
||||
static constexpr inline auto MAIN_HANDLER = "/catboost_request";
|
||||
|
||||
CatBoostLibraryBridgeHelper(ContextPtr context_, std::string_view library_path_, std::string_view model_path_);
|
||||
|
||||
size_t getTreeCount();
|
||||
|
||||
ColumnPtr evaluate(const ColumnsWithTypeAndName & columns);
|
||||
|
||||
protected:
|
||||
Poco::URI getPingURI() const override;
|
||||
|
||||
Poco::URI getMainURI() const override;
|
||||
|
||||
bool bridgeHandShake() override;
|
||||
|
||||
private:
|
||||
static constexpr inline auto CATBOOST_GETTREECOUNT_METHOD = "catboost_GetTreeCount";
|
||||
static constexpr inline auto CATBOOST_LIB_EVALUATE_METHOD = "catboost_libEvaluate";
|
||||
|
||||
Poco::URI createRequestURI(const String & method) const;
|
||||
|
||||
const String library_path;
|
||||
const String model_path;
|
||||
};
|
||||
|
||||
}
|
@ -12,8 +12,8 @@
|
||||
namespace DB
|
||||
{
|
||||
|
||||
/// Common base class for XDBC and Library bridge helpers.
|
||||
/// Contains helper methods to check/start bridge sync.
|
||||
/// Base class for server-side bridge helpers, e.g. xdbc-bridge and library-bridge.
|
||||
/// Contains helper methods to check/start bridge sync
|
||||
class IBridgeHelper: protected WithContext
|
||||
{
|
||||
|
||||
|
@ -176,10 +176,10 @@ static void tryLogCurrentExceptionImpl(Poco::Logger * logger, const std::string
|
||||
|
||||
void tryLogCurrentException(const char * log_name, const std::string & start_of_message)
|
||||
{
|
||||
/// Under high memory pressure, any new allocation will definitelly lead
|
||||
/// to MEMORY_LIMIT_EXCEEDED exception.
|
||||
/// Under high memory pressure, new allocations throw a
|
||||
/// MEMORY_LIMIT_EXCEEDED exception.
|
||||
///
|
||||
/// And in this case the exception will not be logged, so let's block the
|
||||
/// In this case the exception will not be logged, so let's block the
|
||||
/// MemoryTracker until the exception will be logged.
|
||||
LockMemoryExceptionInThread lock_memory_tracker(VariableContext::Global);
|
||||
|
||||
@ -189,8 +189,8 @@ void tryLogCurrentException(const char * log_name, const std::string & start_of_
|
||||
|
||||
void tryLogCurrentException(Poco::Logger * logger, const std::string & start_of_message)
|
||||
{
|
||||
/// Under high memory pressure, any new allocation will definitelly lead
|
||||
/// to MEMORY_LIMIT_EXCEEDED exception.
|
||||
/// Under high memory pressure, new allocations throw a
|
||||
/// MEMORY_LIMIT_EXCEEDED exception.
|
||||
///
|
||||
/// And in this case the exception will not be logged, so let's block the
|
||||
/// MemoryTracker until the exception will be logged.
|
||||
|
182
src/Functions/catboostEvaluate.cpp
Normal file
182
src/Functions/catboostEvaluate.cpp
Normal file
@ -0,0 +1,182 @@
|
||||
#include <Functions/FunctionHelpers.h>
|
||||
#include <Functions/FunctionFactory.h>
|
||||
|
||||
#include <BridgeHelper/CatBoostLibraryBridgeHelper.h>
|
||||
#include <BridgeHelper/IBridgeHelper.h>
|
||||
#include <Columns/ColumnNullable.h>
|
||||
#include <Columns/ColumnString.h>
|
||||
#include <Columns/ColumnTuple.h>
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
#include <Common/assert_cast.h>
|
||||
#include <DataTypes/DataTypeNullable.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <Functions/IFunction.h>
|
||||
#include <Interpreters/Context.h>
|
||||
#include <Interpreters/Context_fwd.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int FILE_DOESNT_EXIST;
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
extern const int TOO_FEW_ARGUMENTS_FOR_FUNCTION;
|
||||
extern const int ILLEGAL_COLUMN;
|
||||
}
|
||||
|
||||
/// Evaluate CatBoost model.
|
||||
/// - Arguments: float features first, then categorical features.
|
||||
/// - Result: Float64.
|
||||
class FunctionCatBoostEvaluate final : public IFunction, WithContext
|
||||
{
|
||||
private:
|
||||
mutable std::unique_ptr<CatBoostLibraryBridgeHelper> bridge_helper;
|
||||
|
||||
public:
|
||||
static constexpr auto name = "catboostEvaluate";
|
||||
|
||||
static FunctionPtr create(ContextPtr context_) { return std::make_shared<FunctionCatBoostEvaluate>(context_); }
|
||||
|
||||
explicit FunctionCatBoostEvaluate(ContextPtr context_) : WithContext(context_) {}
|
||||
String getName() const override { return name; }
|
||||
bool isVariadic() const override { return true; }
|
||||
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }
|
||||
bool isDeterministic() const override { return false; }
|
||||
bool useDefaultImplementationForNulls() const override { return false; }
|
||||
size_t getNumberOfArguments() const override { return 0; }
|
||||
|
||||
void initBridge(const ColumnConst * name_col) const
|
||||
{
|
||||
String library_path = getContext()->getConfigRef().getString("catboost_lib_path");
|
||||
if (!std::filesystem::exists(library_path))
|
||||
throw Exception(ErrorCodes::FILE_DOESNT_EXIST, "Can't load library {}: file doesn't exist", library_path);
|
||||
|
||||
String model_path = name_col->getValue<String>();
|
||||
if (!std::filesystem::exists(model_path))
|
||||
throw Exception(ErrorCodes::FILE_DOESNT_EXIST, "Can't load model {}: file doesn't exist", model_path);
|
||||
|
||||
bridge_helper = std::make_unique<CatBoostLibraryBridgeHelper>(getContext(), library_path, model_path);
|
||||
}
|
||||
|
||||
DataTypePtr getReturnTypeFromLibraryBridge() const
|
||||
{
|
||||
size_t tree_count = bridge_helper->getTreeCount();
|
||||
auto type = std::make_shared<DataTypeFloat64>();
|
||||
if (tree_count == 1)
|
||||
return type;
|
||||
|
||||
DataTypes types(tree_count, type);
|
||||
|
||||
return std::make_shared<DataTypeTuple>(types);
|
||||
}
|
||||
|
||||
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
|
||||
{
|
||||
if (arguments.size() < 2)
|
||||
throw Exception(ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION, "Function {} expects at least 2 arguments", getName());
|
||||
|
||||
if (!isString(arguments[0].type))
|
||||
throw Exception(
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"Illegal type {} of first argument of function {}, expected a string.", arguments[0].type->getName(), getName());
|
||||
|
||||
const auto * name_col = checkAndGetColumnConst<ColumnString>(arguments[0].column.get());
|
||||
if (!name_col)
|
||||
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "First argument of function {} must be a constant string", getName());
|
||||
|
||||
initBridge(name_col);
|
||||
|
||||
auto type = getReturnTypeFromLibraryBridge();
|
||||
|
||||
bool has_nullable = false;
|
||||
for (size_t i = 1; i < arguments.size(); ++i)
|
||||
has_nullable = has_nullable || arguments[i].type->isNullable();
|
||||
|
||||
if (has_nullable)
|
||||
{
|
||||
if (const auto * tuple = typeid_cast<const DataTypeTuple *>(type.get()))
|
||||
{
|
||||
auto elements = tuple->getElements();
|
||||
for (auto & element : elements)
|
||||
element = makeNullable(element);
|
||||
|
||||
type = std::make_shared<DataTypeTuple>(elements);
|
||||
}
|
||||
else
|
||||
type = makeNullable(type);
|
||||
}
|
||||
|
||||
return type;
|
||||
}
|
||||
|
||||
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t) const override
|
||||
{
|
||||
const auto * name_col = checkAndGetColumnConst<ColumnString>(arguments[0].column.get());
|
||||
if (!name_col)
|
||||
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "First argument of function {} must be a constant string", getName());
|
||||
|
||||
ColumnRawPtrs column_ptrs;
|
||||
Columns materialized_columns;
|
||||
ColumnPtr null_map;
|
||||
|
||||
ColumnsWithTypeAndName feature_arguments(arguments.begin() + 1, arguments.end());
|
||||
for (auto & arg : feature_arguments)
|
||||
{
|
||||
if (auto full_column = arg.column->convertToFullColumnIfConst())
|
||||
{
|
||||
materialized_columns.push_back(full_column);
|
||||
arg.column = full_column;
|
||||
}
|
||||
if (const auto * col_nullable = checkAndGetColumn<ColumnNullable>(&*arg.column))
|
||||
{
|
||||
if (!null_map)
|
||||
null_map = col_nullable->getNullMapColumnPtr();
|
||||
else
|
||||
{
|
||||
auto mut_null_map = IColumn::mutate(std::move(null_map));
|
||||
|
||||
NullMap & result_null_map = assert_cast<ColumnUInt8 &>(*mut_null_map).getData();
|
||||
const NullMap & src_null_map = col_nullable->getNullMapColumn().getData();
|
||||
|
||||
for (size_t i = 0, size = result_null_map.size(); i < size; ++i)
|
||||
if (src_null_map[i])
|
||||
result_null_map[i] = 1;
|
||||
|
||||
null_map = std::move(mut_null_map);
|
||||
}
|
||||
|
||||
arg.column = col_nullable->getNestedColumn().getPtr();
|
||||
arg.type = static_cast<const DataTypeNullable &>(*arg.type).getNestedType();
|
||||
}
|
||||
}
|
||||
|
||||
auto res = bridge_helper->evaluate(feature_arguments);
|
||||
|
||||
if (null_map)
|
||||
{
|
||||
if (const auto * tuple = typeid_cast<const ColumnTuple *>(res.get()))
|
||||
{
|
||||
auto nested = tuple->getColumns();
|
||||
for (auto & col : nested)
|
||||
col = ColumnNullable::create(col, null_map);
|
||||
|
||||
res = ColumnTuple::create(nested);
|
||||
}
|
||||
else
|
||||
res = ColumnNullable::create(res, null_map);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
REGISTER_FUNCTION(CatBoostEvaluate)
|
||||
{
|
||||
factory.registerFunction<FunctionCatBoostEvaluate>();
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,3 @@
|
||||
<clickhouse>
|
||||
<catboost_lib_path>/etc/clickhouse-server/model/libcatboostmodel.so</catboost_lib_path>
|
||||
</clickhouse>
|
BIN
tests/integration/test_catboost_evaluate/model/amazon_model.bin
Normal file
BIN
tests/integration/test_catboost_evaluate/model/amazon_model.bin
Normal file
Binary file not shown.
BIN
tests/integration/test_catboost_evaluate/model/libcatboostmodel.so
Executable file
BIN
tests/integration/test_catboost_evaluate/model/libcatboostmodel.so
Executable file
Binary file not shown.
BIN
tests/integration/test_catboost_evaluate/model/simple_model.bin
Normal file
BIN
tests/integration/test_catboost_evaluate/model/simple_model.bin
Normal file
Binary file not shown.
318
tests/integration/test_catboost_evaluate/test.py
Normal file
318
tests/integration/test_catboost_evaluate/test.py
Normal file
@ -0,0 +1,318 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
from helpers.cluster import ClickHouseCluster
|
||||
|
||||
cluster = ClickHouseCluster(__file__)
|
||||
|
||||
instance = cluster.add_instance(
|
||||
"instance", stay_alive=True, main_configs=["config/models_config.xml"]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def ch_cluster():
|
||||
try:
|
||||
cluster.start()
|
||||
|
||||
os.system(
|
||||
"docker cp {local} {cont_id}:{dist}".format(
|
||||
local=os.path.join(SCRIPT_DIR, "model/."),
|
||||
cont_id=instance.docker_id,
|
||||
dist="/etc/clickhouse-server/model",
|
||||
)
|
||||
)
|
||||
instance.restart_clickhouse()
|
||||
|
||||
yield cluster
|
||||
|
||||
finally:
|
||||
cluster.shutdown()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# simple_model.bin has 2 float features and 9 categorical features
|
||||
|
||||
|
||||
def testConstantFeatures(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
result = instance.query(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
|
||||
)
|
||||
expected = "-1.930268705869267\n"
|
||||
assert result == expected
|
||||
|
||||
|
||||
def testNonConstantFeatures(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
instance.query("DROP TABLE IF EXISTS T;")
|
||||
instance.query(
|
||||
"CREATE TABLE T(ID UInt32, F1 Float32, F2 Float32, F3 UInt32, F4 UInt32, F5 UInt32, F6 UInt32, F7 UInt32, F8 UInt32, F9 Float32, F10 Float32, F11 Float32) ENGINE MergeTree ORDER BY ID;"
|
||||
)
|
||||
instance.query("INSERT INTO T VALUES(0, 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);")
|
||||
|
||||
result = instance.query(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', F1, F2, F3, F4, F5, F6, F7, F8, F9, F10, F11) from T;"
|
||||
)
|
||||
expected = "-1.930268705869267\n"
|
||||
assert result == expected
|
||||
|
||||
instance.query("DROP TABLE IF EXISTS T;")
|
||||
|
||||
|
||||
def testModelPathIsNotAConstString(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
err = instance.query_and_get_error(
|
||||
"select catboostEvaluate(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
|
||||
)
|
||||
assert (
|
||||
"Illegal type UInt8 of first argument of function catboostEvaluate, expected a string"
|
||||
in err
|
||||
)
|
||||
|
||||
instance.query("DROP TABLE IF EXISTS T;")
|
||||
instance.query("CREATE TABLE T(ID UInt32, A String) ENGINE MergeTree ORDER BY ID")
|
||||
instance.query("INSERT INTO T VALUES(0, 'test');")
|
||||
err = instance.query_and_get_error(
|
||||
"select catboostEvaluate(A, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11) FROM T;"
|
||||
)
|
||||
assert (
|
||||
"First argument of function catboostEvaluate must be a constant string" in err
|
||||
)
|
||||
instance.query("DROP TABLE IF EXISTS T;")
|
||||
|
||||
|
||||
def testWrongNumberOfFeatureArguments(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
err = instance.query_and_get_error(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin');"
|
||||
)
|
||||
assert "Function catboostEvaluate expects at least 2 arguments" in err
|
||||
|
||||
err = instance.query_and_get_error(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1, 2);"
|
||||
)
|
||||
assert (
|
||||
"Number of columns is different with number of features: columns size 2 float features size 2 + cat features size 9"
|
||||
in err
|
||||
)
|
||||
|
||||
|
||||
def testFloatFeatureMustBeNumeric(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
err = instance.query_and_get_error(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 'a', 3, 4, 5, 6, 7, 8, 9, 10, 11);"
|
||||
)
|
||||
assert "Column 1 should be numeric to make float feature" in err
|
||||
|
||||
|
||||
def testCategoricalFeatureMustBeNumericOrString(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
err = instance.query_and_get_error(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, tuple(8), 9, 10, 11);"
|
||||
)
|
||||
assert "Column 7 should be numeric or string" in err
|
||||
|
||||
|
||||
def testOnLowCardinalityFeatures(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
# same but on domain-compressed data
|
||||
result = instance.query(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', toLowCardinality(1.0), toLowCardinality(2.0), toLowCardinality(3), toLowCardinality(4), toLowCardinality(5), toLowCardinality(6), toLowCardinality(7), toLowCardinality(8), toLowCardinality(9), toLowCardinality(10), toLowCardinality(11));"
|
||||
)
|
||||
expected = "-1.930268705869267\n"
|
||||
assert result == expected
|
||||
|
||||
|
||||
def testOnNullableFeatures(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
result = instance.query(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', toNullable(1.0), toNullable(2.0), toNullable(3), toNullable(4), toNullable(5), toNullable(6), toNullable(7), toNullable(8), toNullable(9), toNullable(10), toNullable(11));"
|
||||
)
|
||||
expected = "-1.930268705869267\n"
|
||||
assert result == expected
|
||||
|
||||
# Actual NULLs are disallowed
|
||||
err = instance.query_and_get_error(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', toNullable(NULL), toNullable(NULL), toNullable(NULL), toNullable(NULL), toNullable(NULL), toNullable(NULL), toNullable(NULL), toNullable(NULL), toNullable(NULL), toNullable(NULL), toNullable(NULL));"
|
||||
)
|
||||
assert "Column 0 should be numeric to make float feature" in err
|
||||
|
||||
|
||||
def testInvalidLibraryPath(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
# temporarily move library elsewhere
|
||||
instance.exec_in_container(
|
||||
[
|
||||
"bash",
|
||||
"-c",
|
||||
"mv /etc/clickhouse-server/model/libcatboostmodel.so /etc/clickhouse-server/model/nonexistant.so",
|
||||
]
|
||||
)
|
||||
|
||||
err = instance.query_and_get_error(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
|
||||
)
|
||||
assert (
|
||||
"Can't load library /etc/clickhouse-server/model/libcatboostmodel.so: file doesn't exist"
|
||||
in err
|
||||
)
|
||||
|
||||
# restore
|
||||
instance.exec_in_container(
|
||||
[
|
||||
"bash",
|
||||
"-c",
|
||||
"mv /etc/clickhouse-server/model/nonexistant.so /etc/clickhouse-server/model/libcatboostmodel.so",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def testInvalidModelPath(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
err = instance.query_and_get_error(
|
||||
"select catboostEvaluate('', 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
|
||||
)
|
||||
assert "Can't load model : file doesn't exist" in err
|
||||
|
||||
err = instance.query_and_get_error(
|
||||
"select catboostEvaluate('model_non_existant.bin', 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
|
||||
)
|
||||
assert "Can't load model model_non_existant.bin: file doesn't exist" in err
|
||||
|
||||
|
||||
def testRecoveryAfterCrash(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
result = instance.query(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
|
||||
)
|
||||
expected = "-1.930268705869267\n"
|
||||
assert result == expected
|
||||
|
||||
instance.exec_in_container(
|
||||
["bash", "-c", "kill -9 `pidof clickhouse-library-bridge`"], user="root"
|
||||
)
|
||||
|
||||
result = instance.query(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
|
||||
)
|
||||
assert result == expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# amazon_model.bin has 0 float features and 9 categorical features
|
||||
|
||||
|
||||
def testAmazonModelSingleRow(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
result = instance.query(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/amazon_model.bin', 1, 2, 3, 4, 5, 6, 7, 8, 9);"
|
||||
)
|
||||
expected = "0.7774665009089274\n"
|
||||
assert result == expected
|
||||
|
||||
|
||||
def testAmazonModelManyRows(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
result = instance.query("drop table if exists amazon")
|
||||
|
||||
result = instance.query(
|
||||
"create table amazon ( DATE Date materialized today(), ACTION UInt8, RESOURCE UInt32, MGR_ID UInt32, ROLE_ROLLUP_1 UInt32, ROLE_ROLLUP_2 UInt32, ROLE_DEPTNAME UInt32, ROLE_TITLE UInt32, ROLE_FAMILY_DESC UInt32, ROLE_FAMILY UInt32, ROLE_CODE UInt32) engine = MergeTree order by DATE"
|
||||
)
|
||||
|
||||
result = instance.query(
|
||||
"insert into amazon select number % 256, number, number, number, number, number, number, number, number, number from numbers(7500)"
|
||||
)
|
||||
|
||||
# First compute prediction, then as a very crude way to fingerprint and compare the result: sum and floor
|
||||
# (the focus is to test that the exchange of large result sets between the server and the bridge works)
|
||||
result = instance.query(
|
||||
"SELECT floor(sum(catboostEvaluate('/etc/clickhouse-server/model/amazon_model.bin', RESOURCE, MGR_ID, ROLE_ROLLUP_1, ROLE_ROLLUP_2, ROLE_DEPTNAME, ROLE_TITLE, ROLE_FAMILY_DESC, ROLE_FAMILY, ROLE_CODE))) FROM amazon"
|
||||
)
|
||||
|
||||
expected = "5834\n"
|
||||
assert result == expected
|
||||
|
||||
result = instance.query("drop table if exists amazon")
|
||||
|
||||
|
||||
def testModelUpdate(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
query = "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
|
||||
|
||||
result = instance.query(query)
|
||||
expected = "-1.930268705869267\n"
|
||||
assert result == expected
|
||||
|
||||
# simulate an update of the model: temporarily move the amazon model in place of the simple model
|
||||
instance.exec_in_container(
|
||||
[
|
||||
"bash",
|
||||
"-c",
|
||||
"mv /etc/clickhouse-server/model/simple_model.bin /etc/clickhouse-server/model/simple_model.bin.bak",
|
||||
]
|
||||
)
|
||||
instance.exec_in_container(
|
||||
[
|
||||
"bash",
|
||||
"-c",
|
||||
"mv /etc/clickhouse-server/model/amazon_model.bin /etc/clickhouse-server/model/simple_model.bin",
|
||||
]
|
||||
)
|
||||
|
||||
# since the amazon model has a different number of features than the simple model, we should get an error
|
||||
err = instance.query_and_get_error(query)
|
||||
assert (
|
||||
"Number of columns is different with number of features: columns size 11 float features size 0 + cat features size 9"
|
||||
in err
|
||||
)
|
||||
|
||||
# restore
|
||||
instance.exec_in_container(
|
||||
[
|
||||
"bash",
|
||||
"-c",
|
||||
"mv /etc/clickhouse-server/model/simple_model.bin /etc/clickhouse-server/model/amazon_model.bin",
|
||||
]
|
||||
)
|
||||
instance.exec_in_container(
|
||||
[
|
||||
"bash",
|
||||
"-c",
|
||||
"mv /etc/clickhouse-server/model/simple_model.bin.bak /etc/clickhouse-server/model/simple_model.bin",
|
||||
]
|
||||
)
|
Loading…
Reference in New Issue
Block a user