feat: implement catboost in library-bridge

This commit moves the catboost model evaluation out of the server
process into the library-bridge binary. This serves two goals: On the
one hand, crashes / memory corruptions of the catboost library no longer
affect the server. On the other hand, we can forbid loading dynamic
libraries in the server (catboost was the last consumer of this
functionality), thus improving security.

SQL syntax:

  SELECT
    catboostEvaluate('/path/to/model.bin', FEAT_1, ..., FEAT_N) > 0 AS prediction,
    ACTION AS target
  FROM amazon_train
  LIMIT 10

Required configuration:

  <catboost_lib_path>/path/to/libcatboostmodel.so</catboost_lib_path>

*** Implementation Details ***

The internal protocol between the server and the library-bridge is
simple:

- HTTP GET on path "/extdict_ping":
  A ping, used during the handshake to check if the library-bridge runs.

- HTTP POST on path "extdict_request"
  (1) Send a "catboost_GetTreeCount" request from the server to the
      bridge, containing a library path (e.g /home/user/libcatboost.so) and
      a model path (e.g. /home/user/model.bin). Rirst, this unloads the
      catboost library handler associated to the model path (if it was
      loaded), then loads the catboost library handler associated to the
      model path, then executes GetTreeCount() on the library handler and
      finally sends the result back to the server. Step (1) is called once
      by the server from FunctionCatBoostEvaluate::getReturnTypeImpl(). The
      library path handler is unloaded in the beginning because it contains
      state which may no longer be valid if the user runs
      catboost("/path/to/model.bin", ...) more than once and if "model.bin"
      was updated in between.
  (2) Send "catboost_Evaluate" from the server to the bridge, containing
      the model path and the features to run the interference on. Step (2)
      is called multiple times (once per chunk) by the server from function
      FunctionCatBoostEvaluate::executeImpl(). The library handler for the
      given model path is expected to be already loaded by Step (1).

Fixes #27870
This commit is contained in:
Robert Schulze 2022-08-05 07:53:06 +00:00
parent 68808858a5
commit 60f9f6855d
No known key found for this signature in database
GPG Key ID: 26703B55FB13728A
26 changed files with 1588 additions and 59 deletions

4
.gitignore vendored
View File

@ -58,6 +58,10 @@ cmake_install.cmake
CTestTestfile.cmake
*.a
*.o
*.so
*.dll
*.lib
*.dylib
cmake-build-*
# Python cache

View File

@ -1823,6 +1823,36 @@ Result:
Evaluate external model.
Accepts a model name and model arguments. Returns Float64.
## catboostEvaluate(path_to_model, feature_1, feature_2, …, feature_n)
Evaluate external catboost model. [CatBoost](https://catboost.ai) is an open-source gradient boosting library developed by Yandex for machine learing.
Accepts a path to a catboost model and model arguments (features). Returns Float64.
``` sql
SELECT feat1, ..., feat_n, catboostEvaluate('/path/to/model.bin', feat_1, ..., feat_n) AS prediction
FROM data_table
```
**Prerequisites**
1. Build the catboost evaluation library
Before evaluating catboost models, the `libcatboostmodel.<so|dylib>` library must be made available. See [CatBoost documentation](https://catboost.ai/docs/concepts/c-plus-plus-api_dynamic-c-pluplus-wrapper.html) how to compile it.
Next, specify the path to `libcatboostmodel.<so|dylib>` in the clickhouse configuration:
``` xml
<clickhouse>
...
<catboost_lib_path>/path/to/libcatboostmodel.so</catboost_lib_path>
...
</clickhouse>
```
2. Train a catboost model using libcatboost
See [Training and applying models](https://catboost.ai/docs/features/training.html#training) for how to train catboost models from a training data set.
## throwIf(x\[, message\[, error_code\]\])
Throw an exception if the argument is non zero.

View File

@ -54,7 +54,7 @@ else ()
endif ()
if (NOT USE_MUSL)
option (ENABLE_CLICKHOUSE_LIBRARY_BRIDGE "HTTP-server working like a proxy to Library dictionary source" ${ENABLE_CLICKHOUSE_ALL})
option (ENABLE_CLICKHOUSE_LIBRARY_BRIDGE "HTTP-server working like a proxy to external dynamically loaded libraries" ${ENABLE_CLICKHOUSE_ALL})
endif ()
# https://presentations.clickhouse.com/matemarketing_2020/

View File

@ -1,6 +1,8 @@
include(${ClickHouse_SOURCE_DIR}/cmake/split_debug_symbols.cmake)
set (CLICKHOUSE_LIBRARY_BRIDGE_SOURCES
CatBoostLibraryHandler.cpp
CatBoostLibraryHandlerFactory.cpp
ExternalDictionaryLibraryAPI.cpp
ExternalDictionaryLibraryHandler.cpp
ExternalDictionaryLibraryHandlerFactory.cpp

View File

@ -0,0 +1,49 @@
#pragma once
#include <cstdint>
#include <cstddef>
// Function pointer typedefs and names of libcatboost.so functions used by ClickHouse
struct CatBoostLibraryAPI
{
using ModelCalcerHandle = void;
using ModelCalcerCreateFunc = ModelCalcerHandle * (*)();
static constexpr const char * ModelCalcerCreateName = "ModelCalcerCreate";
using ModelCalcerDeleteFunc = void (*)(ModelCalcerHandle *);
static constexpr const char * ModelCalcerDeleteName = "ModelCalcerDelete";
using GetErrorStringFunc = const char * (*)();
static constexpr const char * GetErrorStringName = "GetErrorString";
using LoadFullModelFromFileFunc = bool (*)(ModelCalcerHandle *, const char *);
static constexpr const char * LoadFullModelFromFileName = "LoadFullModelFromFile";
using CalcModelPredictionFlatFunc = bool (*)(ModelCalcerHandle *, size_t, const float **, size_t, double *, size_t);
static constexpr const char * CalcModelPredictionFlatName = "CalcModelPredictionFlat";
using CalcModelPredictionFunc = bool (*)(ModelCalcerHandle *, size_t, const float **, size_t, const char ***, size_t, double *, size_t);
static constexpr const char * CalcModelPredictionName = "CalcModelPrediction";
using CalcModelPredictionWithHashedCatFeaturesFunc = bool (*)(ModelCalcerHandle *, size_t, const float **, size_t, const int **, size_t, double *, size_t);
static constexpr const char * CalcModelPredictionWithHashedCatFeaturesName = "CalcModelPredictionWithHashedCatFeatures";
using GetStringCatFeatureHashFunc = int (*)(const char *, size_t);
static constexpr const char * GetStringCatFeatureHashName = "GetStringCatFeatureHash";
using GetIntegerCatFeatureHashFunc = int (*)(uint64_t);
static constexpr const char * GetIntegerCatFeatureHashName = "GetIntegerCatFeatureHash";
using GetFloatFeaturesCountFunc = size_t (*)(ModelCalcerHandle *);
static constexpr const char * GetFloatFeaturesCountName = "GetFloatFeaturesCount";
using GetCatFeaturesCountFunc = size_t (*)(ModelCalcerHandle *);
static constexpr const char * GetCatFeaturesCountName = "GetCatFeaturesCount";
using GetTreeCountFunc = size_t (*)(ModelCalcerHandle *);
static constexpr const char * GetTreeCountName = "GetTreeCount";
using GetDimensionsCountFunc = size_t (*)(ModelCalcerHandle *);
static constexpr const char * GetDimensionsCountName = "GetDimensionsCount";
};

View File

@ -0,0 +1,376 @@
#include "CatBoostLibraryHandler.h"
#include <Columns/ColumnTuple.h>
#include <Common/FieldVisitorConvertToNumber.h>
namespace DB
{
namespace ErrorCodes
{
extern const int BAD_ARGUMENTS;
extern const int CANNOT_APPLY_CATBOOST_MODEL;
extern const int CANNOT_LOAD_CATBOOST_MODEL;
extern const int LOGICAL_ERROR;
}
CatBoostLibraryHandler::APIHolder::APIHolder(SharedLibrary & lib)
{
ModelCalcerCreate = lib.get<CatBoostLibraryAPI::ModelCalcerCreateFunc>(CatBoostLibraryAPI::ModelCalcerCreateName);
ModelCalcerDelete = lib.get<CatBoostLibraryAPI::ModelCalcerDeleteFunc>(CatBoostLibraryAPI::ModelCalcerDeleteName);
GetErrorString = lib.get<CatBoostLibraryAPI::GetErrorStringFunc>(CatBoostLibraryAPI::GetErrorStringName);
LoadFullModelFromFile = lib.get<CatBoostLibraryAPI::LoadFullModelFromFileFunc>(CatBoostLibraryAPI::LoadFullModelFromFileName);
CalcModelPredictionFlat = lib.get<CatBoostLibraryAPI::CalcModelPredictionFlatFunc>(CatBoostLibraryAPI::CalcModelPredictionFlatName);
CalcModelPrediction = lib.get<CatBoostLibraryAPI::CalcModelPredictionFunc>(CatBoostLibraryAPI::CalcModelPredictionName);
CalcModelPredictionWithHashedCatFeatures = lib.get<CatBoostLibraryAPI::CalcModelPredictionWithHashedCatFeaturesFunc>(CatBoostLibraryAPI::CalcModelPredictionWithHashedCatFeaturesName);
GetStringCatFeatureHash = lib.get<CatBoostLibraryAPI::GetStringCatFeatureHashFunc>(CatBoostLibraryAPI::GetStringCatFeatureHashName);
GetIntegerCatFeatureHash = lib.get<CatBoostLibraryAPI::GetIntegerCatFeatureHashFunc>(CatBoostLibraryAPI::GetIntegerCatFeatureHashName);
GetFloatFeaturesCount = lib.get<CatBoostLibraryAPI::GetFloatFeaturesCountFunc>(CatBoostLibraryAPI::GetFloatFeaturesCountName);
GetCatFeaturesCount = lib.get<CatBoostLibraryAPI::GetCatFeaturesCountFunc>(CatBoostLibraryAPI::GetCatFeaturesCountName);
GetTreeCount = lib.tryGet<CatBoostLibraryAPI::GetTreeCountFunc>(CatBoostLibraryAPI::GetTreeCountName);
GetDimensionsCount = lib.tryGet<CatBoostLibraryAPI::GetDimensionsCountFunc>(CatBoostLibraryAPI::GetDimensionsCountName);
}
CatBoostLibraryHandler::CatBoostLibraryHandler(
const std::string & library_path,
const std::string & model_path)
: library(std::make_shared<SharedLibrary>(library_path))
, api(*library)
{
model_calcer_handle = api.ModelCalcerCreate();
if (!api.LoadFullModelFromFile(model_calcer_handle, model_path.c_str()))
{
throw Exception(ErrorCodes::CANNOT_LOAD_CATBOOST_MODEL,
"Cannot load CatBoost model: {}", api.GetErrorString());
}
float_features_count = api.GetFloatFeaturesCount(model_calcer_handle);
cat_features_count = api.GetCatFeaturesCount(model_calcer_handle);
tree_count = 1;
if (api.GetDimensionsCount)
tree_count = api.GetDimensionsCount(model_calcer_handle);
}
CatBoostLibraryHandler::~CatBoostLibraryHandler()
{
api.ModelCalcerDelete(model_calcer_handle);
}
namespace
{
/// Buffer should be allocated with features_count * column->size() elements.
/// Place column elements in positions buffer[0], buffer[features_count], ... , buffer[size * features_count]
template <typename T>
void placeColumnAsNumber(const IColumn * column, T * buffer, size_t features_count)
{
size_t size = column->size();
FieldVisitorConvertToNumber<T> visitor;
for (size_t i = 0; i < size; ++i)
{
/// TODO: Replace with column visitor.
Field field;
column->get(i, field);
*buffer = applyVisitor(visitor, field);
buffer += features_count;
}
}
/// Buffer should be allocated with features_count * column->size() elements.
/// Place string pointers in positions buffer[0], buffer[features_count], ... , buffer[size * features_count]
void placeStringColumn(const ColumnString & column, const char ** buffer, size_t features_count)
{
size_t size = column.size();
for (size_t i = 0; i < size; ++i)
{
*buffer = const_cast<char *>(column.getDataAtWithTerminatingZero(i).data);
buffer += features_count;
}
}
/// Buffer should be allocated with features_count * column->size() elements.
/// Place string pointers in positions buffer[0], buffer[features_count], ... , buffer[size * features_count]
/// Returns PODArray which holds data (because ColumnFixedString doesn't store terminating zero).
PODArray<char> placeFixedStringColumn(const ColumnFixedString & column, const char ** buffer, size_t features_count)
{
size_t size = column.size();
size_t str_size = column.getN();
PODArray<char> data(size * (str_size + 1));
char * data_ptr = data.data();
for (size_t i = 0; i < size; ++i)
{
auto ref = column.getDataAt(i);
memcpy(data_ptr, ref.data, ref.size);
data_ptr[ref.size] = 0;
*buffer = data_ptr;
data_ptr += ref.size + 1;
buffer += features_count;
}
return data;
}
/// Place columns into buffer, returns column which holds placed data. Buffer should contains column->size() values.
template <typename T>
ColumnPtr placeNumericColumns(const ColumnRawPtrs & columns, size_t offset, size_t size, const T** buffer)
{
if (size == 0)
return nullptr;
size_t column_size = columns[offset]->size();
auto data_column = ColumnVector<T>::create(size * column_size);
T * data = data_column->getData().data();
for (size_t i = 0; i < size; ++i)
{
const auto * column = columns[offset + i];
if (column->isNumeric())
placeColumnAsNumber(column, data + i, size);
}
for (size_t i = 0; i < column_size; ++i)
{
*buffer = data;
++buffer;
data += size;
}
return data_column;
}
/// Place columns into buffer, returns data which was used for fixed string columns.
/// Buffer should contains column->size() values, each value contains size strings.
std::vector<PODArray<char>> placeStringColumns(const ColumnRawPtrs & columns, size_t offset, size_t size, const char ** buffer)
{
if (size == 0)
return {};
std::vector<PODArray<char>> data;
for (size_t i = 0; i < size; ++i)
{
const auto * column = columns[offset + i];
if (const auto * column_string = typeid_cast<const ColumnString *>(column))
placeStringColumn(*column_string, buffer + i, size);
else if (const auto * column_fixed_string = typeid_cast<const ColumnFixedString *>(column))
data.push_back(placeFixedStringColumn(*column_fixed_string, buffer + i, size));
else
throw Exception("Cannot place string column.", ErrorCodes::LOGICAL_ERROR);
}
return data;
}
/// buffer[column_size * cat_features_count] -> char * => cat_features[column_size][cat_features_count] -> char *
void fillCatFeaturesBuffer(
const char *** cat_features, const char ** buffer,
size_t column_size, size_t cat_features_count)
{
for (size_t i = 0; i < column_size; ++i)
{
*cat_features = buffer;
++cat_features;
buffer += cat_features_count;
}
}
/// Calc hash for string cat feature at ps positions.
template <typename Column>
void calcStringHashes(const Column * column, size_t ps, const int ** buffer, const CatBoostLibraryHandler::APIHolder & api)
{
size_t column_size = column->size();
for (size_t j = 0; j < column_size; ++j)
{
auto ref = column->getDataAt(j);
const_cast<int *>(*buffer)[ps] = api.GetStringCatFeatureHash(ref.data, ref.size);
++buffer;
}
}
/// Calc hash for int cat feature at ps position. Buffer at positions ps should contains unhashed values.
void calcIntHashes(size_t column_size, size_t ps, const int ** buffer, const CatBoostLibraryHandler::APIHolder & api)
{
for (size_t j = 0; j < column_size; ++j)
{
const_cast<int *>(*buffer)[ps] = api.GetIntegerCatFeatureHash((*buffer)[ps]);
++buffer;
}
}
/// buffer contains column->size() rows and size columns.
/// For int cat features calc hash inplace.
/// For string cat features calc hash from column rows.
void calcHashes(const ColumnRawPtrs & columns, size_t offset, size_t size, const int ** buffer, const CatBoostLibraryHandler::APIHolder & api)
{
if (size == 0)
return;
size_t column_size = columns[offset]->size();
std::vector<PODArray<char>> data;
for (size_t i = 0; i < size; ++i)
{
const auto * column = columns[offset + i];
if (const auto * column_string = typeid_cast<const ColumnString *>(column))
calcStringHashes(column_string, i, buffer, api);
else if (const auto * column_fixed_string = typeid_cast<const ColumnFixedString *>(column))
calcStringHashes(column_fixed_string, i, buffer, api);
else
calcIntHashes(column_size, i, buffer, api);
}
}
}
/// Convert values to row-oriented format and call evaluation function from CatBoost wrapper api.
/// * CalcModelPredictionFlat if no cat features
/// * CalcModelPrediction if all cat features are strings
/// * CalcModelPredictionWithHashedCatFeatures if has int cat features.
ColumnFloat64::MutablePtr CatBoostLibraryHandler::evalImpl(
const ColumnRawPtrs & columns,
bool cat_features_are_strings) const
{
std::string error_msg = "Error occurred while applying CatBoost model: ";
size_t column_size = columns.front()->size();
auto result = ColumnFloat64::create(column_size * tree_count);
auto * result_buf = result->getData().data();
if (!column_size)
return result;
/// Prepare float features.
PODArray<const float *> float_features(column_size);
auto * float_features_buf = float_features.data();
/// Store all float data into single column. float_features is a list of pointers to it.
auto float_features_col = placeNumericColumns<float>(columns, 0, float_features_count, float_features_buf);
if (cat_features_count == 0)
{
if (!api.CalcModelPredictionFlat(model_calcer_handle, column_size,
float_features_buf, float_features_count,
result_buf, column_size * tree_count))
{
throw Exception(error_msg + api.GetErrorString(), ErrorCodes::CANNOT_APPLY_CATBOOST_MODEL);
}
return result;
}
/// Prepare cat features.
if (cat_features_are_strings)
{
/// cat_features_holder stores pointers to ColumnString data or fixed_strings_data.
PODArray<const char *> cat_features_holder(cat_features_count * column_size);
PODArray<const char **> cat_features(column_size);
auto * cat_features_buf = cat_features.data();
fillCatFeaturesBuffer(cat_features_buf, cat_features_holder.data(), column_size, cat_features_count);
/// Fixed strings are stored without termination zero, so have to copy data into fixed_strings_data.
auto fixed_strings_data = placeStringColumns(columns, float_features_count,
cat_features_count, cat_features_holder.data());
if (!api.CalcModelPrediction(model_calcer_handle, column_size,
float_features_buf, float_features_count,
cat_features_buf, cat_features_count,
result_buf, column_size * tree_count))
{
throw Exception(error_msg + api.GetErrorString(), ErrorCodes::CANNOT_APPLY_CATBOOST_MODEL);
}
}
else
{
PODArray<const int *> cat_features(column_size);
auto * cat_features_buf = cat_features.data();
auto cat_features_col = placeNumericColumns<int>(columns, float_features_count,
cat_features_count, cat_features_buf);
calcHashes(columns, float_features_count, cat_features_count, cat_features_buf, api);
if (!api.CalcModelPredictionWithHashedCatFeatures(
model_calcer_handle, column_size,
float_features_buf, float_features_count,
cat_features_buf, cat_features_count,
result_buf, column_size * tree_count))
{
throw Exception(error_msg + api.GetErrorString(), ErrorCodes::CANNOT_APPLY_CATBOOST_MODEL);
}
}
return result;
}
size_t CatBoostLibraryHandler::getTreeCount() const
{
std::lock_guard lock(mutex);
return tree_count;
}
ColumnPtr CatBoostLibraryHandler::evaluate(const ColumnRawPtrs & columns) const
{
std::lock_guard lock(mutex);
if (columns.empty())
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Got empty columns list for CatBoost model.");
if (columns.size() != float_features_count + cat_features_count)
throw Exception(ErrorCodes::BAD_ARGUMENTS,
"Number of columns is different with number of features: columns size {} float features size {} + cat features size {}",
columns.size(),
float_features_count,
cat_features_count);
for (size_t i = 0; i < float_features_count; ++i)
{
if (!columns[i]->isNumeric())
{
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Column {} should be numeric to make float feature.", i);
}
}
bool cat_features_are_strings = true;
for (size_t i = float_features_count; i < float_features_count + cat_features_count; ++i)
{
const auto * column = columns[i];
if (column->isNumeric())
{
cat_features_are_strings = false;
}
else if (!(typeid_cast<const ColumnString *>(column)
|| typeid_cast<const ColumnFixedString *>(column)))
{
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Column {} should be numeric or string.", i);
}
}
auto result = evalImpl(columns, cat_features_are_strings);
if (tree_count == 1)
return result;
size_t column_size = columns.front()->size();
auto * result_buf = result->getData().data();
/// Multiple trees case. Copy data to several columns.
MutableColumns mutable_columns(tree_count);
std::vector<Float64 *> column_ptrs(tree_count);
for (size_t i = 0; i < tree_count; ++i)
{
auto col = ColumnFloat64::create(column_size);
column_ptrs[i] = col->getData().data();
mutable_columns[i] = std::move(col);
}
Float64 * data = result_buf;
for (size_t row = 0; row < column_size; ++row)
{
for (size_t i = 0; i < tree_count; ++i)
{
*column_ptrs[i] = *data;
++column_ptrs[i];
++data;
}
}
return ColumnTuple::create(std::move(mutable_columns));
}
}

View File

@ -0,0 +1,71 @@
#pragma once
#include "CatBoostLibraryAPI.h"
#include <Columns/ColumnFixedString.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnVector.h>
#include <Columns/ColumnsNumber.h>
#include <Columns/IColumn.h>
#include <Common/SharedLibrary.h>
#include <base/defines.h>
#include <mutex>
namespace DB
{
/// Abstracts access to the CatBoost shared library.
class CatBoostLibraryHandler
{
public:
/// Holds pointers to CatBoost library functions
struct APIHolder
{
explicit APIHolder(SharedLibrary & lib);
// NOLINTBEGIN(readability-identifier-naming)
CatBoostLibraryAPI::ModelCalcerCreateFunc ModelCalcerCreate;
CatBoostLibraryAPI::ModelCalcerDeleteFunc ModelCalcerDelete;
CatBoostLibraryAPI::GetErrorStringFunc GetErrorString;
CatBoostLibraryAPI::LoadFullModelFromFileFunc LoadFullModelFromFile;
CatBoostLibraryAPI::CalcModelPredictionFlatFunc CalcModelPredictionFlat;
CatBoostLibraryAPI::CalcModelPredictionFunc CalcModelPrediction;
CatBoostLibraryAPI::CalcModelPredictionWithHashedCatFeaturesFunc CalcModelPredictionWithHashedCatFeatures;
CatBoostLibraryAPI::GetStringCatFeatureHashFunc GetStringCatFeatureHash;
CatBoostLibraryAPI::GetIntegerCatFeatureHashFunc GetIntegerCatFeatureHash;
CatBoostLibraryAPI::GetFloatFeaturesCountFunc GetFloatFeaturesCount;
CatBoostLibraryAPI::GetCatFeaturesCountFunc GetCatFeaturesCount;
CatBoostLibraryAPI::GetTreeCountFunc GetTreeCount;
CatBoostLibraryAPI::GetDimensionsCountFunc GetDimensionsCount;
// NOLINTEND(readability-identifier-naming)
};
CatBoostLibraryHandler(
const std::string & library_path,
const std::string & model_path);
~CatBoostLibraryHandler();
size_t getTreeCount() const;
ColumnPtr evaluate(const ColumnRawPtrs & columns) const;
private:
const SharedLibraryPtr library;
const APIHolder api;
mutable std::mutex mutex;
CatBoostLibraryAPI::ModelCalcerHandle * model_calcer_handle TSA_GUARDED_BY(mutex) TSA_PT_GUARDED_BY(mutex);
size_t float_features_count TSA_GUARDED_BY(mutex);
size_t cat_features_count TSA_GUARDED_BY(mutex);
size_t tree_count TSA_GUARDED_BY(mutex);
ColumnFloat64::MutablePtr evalImpl(const ColumnRawPtrs & columns, bool cat_features_are_strings) const TSA_REQUIRES(mutex);
};
using CatBoostLibraryHandlerPtr = std::shared_ptr<CatBoostLibraryHandler>;
}

View File

@ -0,0 +1,49 @@
#include "CatBoostLibraryHandlerFactory.h"
#include <Common/logger_useful.h>
namespace DB
{
CatBoostLibraryHandlerFactory & CatBoostLibraryHandlerFactory::instance()
{
static CatBoostLibraryHandlerFactory instance;
return instance;
}
CatBoostLibraryHandlerPtr CatBoostLibraryHandlerFactory::get(const String & model_path)
{
std::lock_guard lock(mutex);
if (auto handler = library_handlers.find(model_path); handler != library_handlers.end())
return handler->second;
return nullptr;
}
void CatBoostLibraryHandlerFactory::create(const String & library_path, const String & model_path)
{
std::lock_guard lock(mutex);
if (library_handlers.contains(model_path))
{
LOG_DEBUG(&Poco::Logger::get("CatBoostLibraryHandlerFactory"), "Cannot load catboost library handler for model path {} because it exists already", model_path);
return;
}
library_handlers.emplace(std::make_pair(model_path, std::make_shared<CatBoostLibraryHandler>(library_path, model_path)));
LOG_DEBUG(&Poco::Logger::get("CatBoostLibraryHandlerFactory"), "Loaded catboost library handler for model path {}.", model_path);
}
void CatBoostLibraryHandlerFactory::remove(const String & model_path)
{
std::lock_guard lock(mutex);
bool deleted = library_handlers.erase(model_path);
if (!deleted)
{
LOG_DEBUG(&Poco::Logger::get("CatBoostLibraryHandlerFactory"), "Cannot unload catboost library handler for model path: {}", model_path);
return;
}
LOG_DEBUG(&Poco::Logger::get("CatBoostLibraryHandlerFactory"), "Unloaded catboost library handler for model path: {}", model_path);
}
}

View File

@ -0,0 +1,31 @@
#pragma once
#include "CatBoostLibraryHandler.h"
#include <base/defines.h>
#include <mutex>
#include <unordered_map>
namespace DB
{
class CatBoostLibraryHandlerFactory final : private boost::noncopyable
{
public:
static CatBoostLibraryHandlerFactory & instance();
CatBoostLibraryHandlerPtr get(const String & model_path);
void create(const String & library_path, const String & model_path);
void remove(const String & model_path);
private:
/// map: model path -> shared library handler
std::unordered_map<String, CatBoostLibraryHandlerPtr> library_handlers TSA_GUARDED_BY(mutex);
std::mutex mutex;
};
}

View File

@ -50,6 +50,6 @@ private:
void * lib_data;
};
using SharedLibraryHandlerPtr = std::shared_ptr<ExternalDictionaryLibraryHandler>;
using ExternalDictionaryLibraryHandlerPtr = std::shared_ptr<ExternalDictionaryLibraryHandler>;
}

View File

@ -1,37 +1,40 @@
#include "ExternalDictionaryLibraryHandlerFactory.h"
#include <Common/logger_useful.h>
namespace DB
{
SharedLibraryHandlerPtr ExternalDictionaryLibraryHandlerFactory::get(const std::string & dictionary_id)
ExternalDictionaryLibraryHandlerPtr ExternalDictionaryLibraryHandlerFactory::get(const String & dictionary_id)
{
std::lock_guard lock(mutex);
auto library_handler = library_handlers.find(dictionary_id);
if (library_handler != library_handlers.end())
return library_handler->second;
if (auto handler = library_handlers.find(dictionary_id); handler != library_handlers.end())
return handler->second;
return nullptr;
}
void ExternalDictionaryLibraryHandlerFactory::create(
const std::string & dictionary_id,
const std::string & library_path,
const std::vector<std::string> & library_settings,
const String & dictionary_id,
const String & library_path,
const std::vector<String> & library_settings,
const Block & sample_block,
const std::vector<std::string> & attributes_names)
const std::vector<String> & attributes_names)
{
std::lock_guard lock(mutex);
if (!library_handlers.contains(dictionary_id))
library_handlers.emplace(std::make_pair(dictionary_id, std::make_shared<ExternalDictionaryLibraryHandler>(library_path, library_settings, sample_block, attributes_names)));
else
if (library_handlers.contains(dictionary_id))
{
LOG_WARNING(&Poco::Logger::get("ExternalDictionaryLibraryHandlerFactory"), "Library handler with dictionary id {} already exists", dictionary_id);
return;
}
library_handlers.emplace(std::make_pair(dictionary_id, std::make_shared<ExternalDictionaryLibraryHandler>(library_path, library_settings, sample_block, attributes_names)));
}
bool ExternalDictionaryLibraryHandlerFactory::clone(const std::string & from_dictionary_id, const std::string & to_dictionary_id)
bool ExternalDictionaryLibraryHandlerFactory::clone(const String & from_dictionary_id, const String & to_dictionary_id)
{
std::lock_guard lock(mutex);
auto from_library_handler = library_handlers.find(from_dictionary_id);
@ -45,7 +48,7 @@ bool ExternalDictionaryLibraryHandlerFactory::clone(const std::string & from_dic
}
bool ExternalDictionaryLibraryHandlerFactory::remove(const std::string & dictionary_id)
bool ExternalDictionaryLibraryHandlerFactory::remove(const String & dictionary_id)
{
std::lock_guard lock(mutex);
/// extDict_libDelete is called in destructor.

View File

@ -17,22 +17,22 @@ class ExternalDictionaryLibraryHandlerFactory final : private boost::noncopyable
public:
static ExternalDictionaryLibraryHandlerFactory & instance();
SharedLibraryHandlerPtr get(const std::string & dictionary_id);
ExternalDictionaryLibraryHandlerPtr get(const String & dictionary_id);
void create(
const std::string & dictionary_id,
const std::string & library_path,
const std::vector<std::string> & library_settings,
const String & dictionary_id,
const String & library_path,
const std::vector<String> & library_settings,
const Block & sample_block,
const std::vector<std::string> & attributes_names);
const std::vector<String> & attributes_names);
bool clone(const std::string & from_dictionary_id, const std::string & to_dictionary_id);
bool clone(const String & from_dictionary_id, const String & to_dictionary_id);
bool remove(const std::string & dictionary_id);
bool remove(const String & dictionary_id);
private:
/// map: dict_id -> sharedLibraryHandler
std::unordered_map<std::string, SharedLibraryHandlerPtr> library_handlers TSA_GUARDED_BY(mutex);
std::unordered_map<String, ExternalDictionaryLibraryHandlerPtr> library_handlers TSA_GUARDED_BY(mutex);
std::mutex mutex;
};

View File

@ -27,12 +27,16 @@ std::unique_ptr<HTTPRequestHandler> LibraryBridgeHandlerFactory::createRequestHa
{
if (uri.getPath() == "/extdict_ping")
return std::make_unique<ExternalDictionaryLibraryBridgeExistsHandler>(keep_alive_timeout, getContext());
else if (uri.getPath() == "/catboost_ping")
return std::make_unique<CatBoostLibraryBridgeExistsHandler>(keep_alive_timeout, getContext());
}
if (request.getMethod() == Poco::Net::HTTPRequest::HTTP_POST)
{
if (uri.getPath() == "/extdict_request")
return std::make_unique<ExternalDictionaryLibraryBridgeRequestHandler>(keep_alive_timeout, getContext());
else if (uri.getPath() == "/catboost_request")
return std::make_unique<CatBoostLibraryBridgeRequestHandler>(keep_alive_timeout, getContext());
}
return nullptr;

View File

@ -1,24 +1,31 @@
#include "LibraryBridgeHandlers.h"
#include "CatBoostLibraryHandler.h"
#include "CatBoostLibraryHandlerFactory.h"
#include "ExternalDictionaryLibraryHandler.h"
#include "ExternalDictionaryLibraryHandlerFactory.h"
#include <Formats/FormatFactory.h>
#include <Server/HTTP/WriteBufferFromHTTPServerResponse.h>
#include <IO/WriteHelpers.h>
#include <IO/ReadBufferFromString.h>
#include <IO/ReadHelpers.h>
#include <Common/BridgeProtocolVersion.h>
#include <IO/WriteHelpers.h>
#include <Poco/Net/HTMLForm.h>
#include <Poco/Net/HTTPServerRequest.h>
#include <Poco/Net/HTTPServerResponse.h>
#include <Poco/Net/HTMLForm.h>
#include <Poco/ThreadPool.h>
#include <Processors/Formats/IOutputFormat.h>
#include <Processors/Formats/IInputFormat.h>
#include <QueryPipeline/QueryPipeline.h>
#include <Processors/Executors/CompletedPipelineExecutor.h>
#include <Processors/Executors/PullingPipelineExecutor.h>
#include <Processors/Formats/IInputFormat.h>
#include <Processors/Formats/IOutputFormat.h>
#include <Processors/Sources/SourceFromSingleChunk.h>
#include <QueryPipeline/Pipe.h>
#include <QueryPipeline/QueryPipeline.h>
#include <Server/HTTP/HTMLForm.h>
#include <IO/ReadBufferFromString.h>
#include <Server/HTTP/WriteBufferFromHTTPServerResponse.h>
#include <Formats/NativeReader.h>
#include <Formats/NativeWriter.h>
#include <DataTypes/DataTypesNumber.h>
namespace DB
@ -31,7 +38,7 @@ namespace ErrorCodes
namespace
{
void processError(HTTPServerResponse & response, const std::string & message)
void processError(HTTPServerResponse & response, const String & message)
{
response.setStatusAndReason(HTTPResponse::HTTP_INTERNAL_SERVER_ERROR);
@ -41,7 +48,7 @@ namespace
LOG_WARNING(&Poco::Logger::get("LibraryBridge"), fmt::runtime(message));
}
std::shared_ptr<Block> parseColumns(std::string && column_string)
std::shared_ptr<Block> parseColumns(String && column_string)
{
auto sample_block = std::make_shared<Block>();
auto names_and_types = NamesAndTypesList::parse(column_string);
@ -59,10 +66,10 @@ namespace
return ids;
}
std::vector<std::string> parseNamesFromBinary(const std::string & names_string)
std::vector<String> parseNamesFromBinary(const String & names_string)
{
ReadBufferFromString buf(names_string);
std::vector<std::string> names;
std::vector<String> names;
readVectorBinary(names, buf);
return names;
}
@ -79,13 +86,15 @@ static void writeData(Block data, OutputFormatPtr format)
executor.execute();
}
ExternalDictionaryLibraryBridgeRequestHandler::ExternalDictionaryLibraryBridgeRequestHandler(size_t keep_alive_timeout_, ContextPtr context_)
: WithContext(context_)
, log(&Poco::Logger::get("ExternalDictionaryLibraryBridgeRequestHandler"))
, keep_alive_timeout(keep_alive_timeout_)
, log(&Poco::Logger::get("ExternalDictionaryLibraryBridgeRequestHandler"))
{
}
void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response)
{
LOG_TRACE(log, "Request URI: {}", request.getURI());
@ -97,7 +106,7 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
version = 0; /// assumed version for too old servers which do not send a version
else
{
String version_str = params.get("version");
const String & version_str = params.get("version");
if (!tryParse(version, version_str))
{
processError(response, "Unable to parse 'version' string in request URL: '" + version_str + "' Check if the server and library-bridge have the same version.");
@ -124,8 +133,8 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
return;
}
std::string method = params.get("method");
std::string dictionary_id = params.get("dictionary_id");
const String & method = params.get("method");
const String & dictionary_id = params.get("dictionary_id");
LOG_TRACE(log, "Library method: '{}', dictionary id: {}", method, dictionary_id);
WriteBufferFromHTTPServerResponse out(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD, keep_alive_timeout);
@ -141,7 +150,7 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
return;
}
std::string from_dictionary_id = params.get("from_dictionary_id");
const String & from_dictionary_id = params.get("from_dictionary_id");
bool cloned = false;
cloned = ExternalDictionaryLibraryHandlerFactory::instance().clone(from_dictionary_id, dictionary_id);
@ -166,7 +175,7 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
return;
}
std::string library_path = params.get("library_path");
const String & library_path = params.get("library_path");
if (!params.has("library_settings"))
{
@ -174,10 +183,10 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
return;
}
const auto & settings_string = params.get("library_settings");
const String & settings_string = params.get("library_settings");
LOG_DEBUG(log, "Parsing library settings from binary string");
std::vector<std::string> library_settings = parseNamesFromBinary(settings_string);
std::vector<String> library_settings = parseNamesFromBinary(settings_string);
/// Needed for library dictionary
if (!params.has("attributes_names"))
@ -186,10 +195,10 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
return;
}
const auto & attributes_string = params.get("attributes_names");
const String & attributes_string = params.get("attributes_names");
LOG_DEBUG(log, "Parsing attributes names from binary string");
std::vector<std::string> attributes_names = parseNamesFromBinary(attributes_string);
std::vector<String> attributes_names = parseNamesFromBinary(attributes_string);
/// Needed to parse block from binary string format
if (!params.has("sample_block"))
@ -197,7 +206,7 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
processError(response, "No 'sample_block' in request URL");
return;
}
std::string sample_block_string = params.get("sample_block");
String sample_block_string = params.get("sample_block");
std::shared_ptr<Block> sample_block;
try
@ -297,7 +306,7 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
return;
}
std::string requested_block_string = params.get("requested_block_sample");
String requested_block_string = params.get("requested_block_sample");
std::shared_ptr<Block> requested_sample_block;
try
@ -332,7 +341,8 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
}
else
{
LOG_WARNING(log, "Unknown library method: '{}'", method);
processError(response, "Unknown library method '" + method + "'");
LOG_ERROR(log, "Unknown library method: '{}'", method);
}
}
catch (...)
@ -362,6 +372,7 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
}
}
ExternalDictionaryLibraryBridgeExistsHandler::ExternalDictionaryLibraryBridgeExistsHandler(size_t keep_alive_timeout_, ContextPtr context_)
: WithContext(context_)
, keep_alive_timeout(keep_alive_timeout_)
@ -369,6 +380,7 @@ ExternalDictionaryLibraryBridgeExistsHandler::ExternalDictionaryLibraryBridgeExi
{
}
void ExternalDictionaryLibraryBridgeExistsHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response)
{
try
@ -382,7 +394,7 @@ void ExternalDictionaryLibraryBridgeExistsHandler::handleRequest(HTTPServerReque
return;
}
std::string dictionary_id = params.get("dictionary_id");
const String & dictionary_id = params.get("dictionary_id");
auto library_handler = ExternalDictionaryLibraryHandlerFactory::instance().get(dictionary_id);
@ -399,4 +411,199 @@ void ExternalDictionaryLibraryBridgeExistsHandler::handleRequest(HTTPServerReque
}
CatBoostLibraryBridgeRequestHandler::CatBoostLibraryBridgeRequestHandler(
size_t keep_alive_timeout_, ContextPtr context_)
: WithContext(context_)
, keep_alive_timeout(keep_alive_timeout_)
, log(&Poco::Logger::get("CatBoostLibraryBridgeRequestHandler"))
{
}
void CatBoostLibraryBridgeRequestHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response)
{
std::lock_guard lock(mutex);
LOG_TRACE(log, "Request URI: {}", request.getURI());
HTMLForm params(getContext()->getSettingsRef(), request);
size_t version;
if (!params.has("version"))
version = 0; /// assumed version for too old servers which do not send a version
else
{
const String & version_str = params.get("version");
if (!tryParse(version, version_str))
{
processError(response, "Unable to parse 'version' string in request URL: '" + version_str + "' Check if the server and library-bridge have the same version.");
return;
}
}
if (version != LIBRARY_BRIDGE_PROTOCOL_VERSION)
{
/// backwards compatibility is considered unnecessary for now, just let the user know that the server and the bridge must be upgraded together
processError(response, "Server and library-bridge have different versions: '" + std::to_string(version) + "' vs. '" + std::to_string(LIBRARY_BRIDGE_PROTOCOL_VERSION) + "'");
return;
}
if (!params.has("method"))
{
processError(response, "No 'method' in request URL");
return;
}
const String & method = params.get("method");
LOG_TRACE(log, "Library method: '{}'", method);
WriteBufferFromHTTPServerResponse out(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD, keep_alive_timeout);
try
{
if (method == "catboost_GetTreeCount")
{
auto & read_buf = request.getStream();
params.read(read_buf);
if (!params.has("library_path"))
{
processError(response, "No 'library_path' in request URL");
return;
}
const String & library_path = params.get("library_path");
if (!params.has("model_path"))
{
processError(response, "No 'model_path' in request URL");
return;
}
const String & model_path = params.get("model_path");
CatBoostLibraryHandlerFactory::instance().remove(model_path);
CatBoostLibraryHandlerFactory::instance().create(library_path, model_path);
auto catboost_handler = CatBoostLibraryHandlerFactory::instance().get(model_path);
if (!catboost_handler)
{
processError(response, "CatBoost library is not loaded for model " + model_path);
return;
}
size_t tree_count = catboost_handler->getTreeCount();
writeIntBinary(tree_count, out);
}
else if (method == "catboost_libEvaluate")
{
auto & read_buf = request.getStream();
params.read(read_buf);
if (!params.has("model_path"))
{
processError(response, "No 'model_path' in request URL");
return;
}
const String & model_path = params.get("model_path");
if (!params.has("data"))
{
processError(response, "No 'data' in request URL");
return;
}
const String & data = params.get("data");
ReadBufferFromString string_read_buf(data);
NativeReader deserializer(string_read_buf, /*server_revision*/ 0);
Block block_read = deserializer.read();
Columns col_ptrs = block_read.getColumns();
ColumnRawPtrs col_raw_ptrs;
for (const auto & p : col_ptrs)
col_raw_ptrs.push_back(&*p);
auto catboost_handler = CatBoostLibraryHandlerFactory::instance().get(model_path);
if (!catboost_handler)
{
processError(response, "CatBoost library is not loaded for model" + model_path);
return;
}
ColumnPtr res_col = catboost_handler->evaluate(col_raw_ptrs);
DataTypePtr res_col_type = std::make_shared<DataTypeFloat64>();
String res_col_name = "res_col";
ColumnsWithTypeAndName res_cols_with_type_and_name = {{res_col, res_col_type, res_col_name}};
Block block_write(res_cols_with_type_and_name);
NativeWriter serializer{out, /*client_revision*/ 0, block_write};
serializer.write(block_write);
}
else
{
processError(response, "Unknown library method '" + method + "'");
LOG_ERROR(log, "Unknown library method: '{}'", method);
}
}
catch (...)
{
auto message = getCurrentExceptionMessage(true);
LOG_ERROR(log, "Failed to process request. Error: {}", message);
response.setStatusAndReason(Poco::Net::HTTPResponse::HTTP_INTERNAL_SERVER_ERROR, message); // can't call process_error, because of too soon response sending
try
{
writeStringBinary(message, out);
out.finalize();
}
catch (...)
{
tryLogCurrentException(log);
}
}
try
{
out.finalize();
}
catch (...)
{
tryLogCurrentException(log);
}
}
CatBoostLibraryBridgeExistsHandler::CatBoostLibraryBridgeExistsHandler(size_t keep_alive_timeout_, ContextPtr context_)
: WithContext(context_)
, keep_alive_timeout(keep_alive_timeout_)
, log(&Poco::Logger::get("CatBoostLibraryBridgeExistsHandler"))
{
}
void CatBoostLibraryBridgeExistsHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response)
{
try
{
LOG_TRACE(log, "Request URI: {}", request.getURI());
HTMLForm params(getContext()->getSettingsRef(), request);
String res = "1";
setResponseDefaultHeaders(response, keep_alive_timeout);
LOG_TRACE(log, "Sending ping response: {}", res);
response.sendBuffer(res.data(), res.size());
}
catch (...)
{
tryLogCurrentException("PingHandler");
}
}
}

View File

@ -1,9 +1,9 @@
#pragma once
#include <Common/logger_useful.h>
#include <Interpreters/Context.h>
#include <Server/HTTP/HTTPRequestHandler.h>
#include <Common/logger_useful.h>
#include "ExternalDictionaryLibraryHandler.h"
#include <mutex>
namespace DB
@ -26,11 +26,12 @@ public:
private:
static constexpr inline auto FORMAT = "RowBinary";
const size_t keep_alive_timeout;
Poco::Logger * log;
size_t keep_alive_timeout;
};
// Handler for checking if the external dictionary library is loaded (used for handshake)
class ExternalDictionaryLibraryBridgeExistsHandler : public HTTPRequestHandler, WithContext
{
public:
@ -43,4 +44,43 @@ private:
Poco::Logger * log;
};
/// Handler for requests to catboost library. The call protocol is as follows:
/// (1) Send a "catboost_GetTreeCount" request from the server to the bridge, containing a library path (e.g /home/user/libcatboost.so) and
/// a model path (e.g. /home/user/model.bin). Rirst, this unloads the catboost library handler associated to the model path (if it was
/// loaded), then loads the catboost library handler associated to the model path, then executes GetTreeCount() on the library handler
/// and finally sends the result back to the server.
/// Step (1) is called once by the server from FunctionCatBoostEvaluate::getReturnTypeImpl(). The library path handler is unloaded in
/// the beginning because it contains state which may no longer be valid if the user runs catboost("/path/to/model.bin", ...) more than
/// once and if "model.bin" was updated in between.
/// (2) Send "catboost_Evaluate" from the server to the bridge, containing the model path and the features to run the interference on.
/// Step (2) is called multiple times (once per chunk) by the server from function FunctionCatBoostEvaluate::executeImpl(). The library
/// handler for the given model path is expected to be already loaded by Step (1).
class CatBoostLibraryBridgeRequestHandler : public HTTPRequestHandler, WithContext
{
public:
CatBoostLibraryBridgeRequestHandler(size_t keep_alive_timeout_, ContextPtr context_);
void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) override;
private:
std::mutex mutex;
const size_t keep_alive_timeout;
Poco::Logger * log;
};
// Handler for pinging the library-bridge for catboost access (used for handshake)
class CatBoostLibraryBridgeExistsHandler : public HTTPRequestHandler, WithContext
{
public:
CatBoostLibraryBridgeExistsHandler(size_t keep_alive_timeout_, ContextPtr context_);
void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) override;
private:
const size_t keep_alive_timeout;
Poco::Logger * log;
};
}

View File

@ -0,0 +1,118 @@
#include "CatBoostLibraryBridgeHelper.h"
#include <Columns/ColumnsNumber.h>
#include <Common/escapeForFileName.h>
#include <Core/Block.h>
#include <DataTypes/DataTypesNumber.h>
#include <Formats/NativeReader.h>
#include <Formats/NativeWriter.h>
#include <IO/ReadBufferFromString.h>
#include <IO/WriteBufferFromString.h>
#include <Poco/Net/HTTPRequest.h>
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}
CatBoostLibraryBridgeHelper::CatBoostLibraryBridgeHelper(
ContextPtr context_,
std::string_view library_path_,
std::string_view model_path_)
: LibraryBridgeHelper(context_->getGlobalContext())
, library_path(library_path_)
, model_path(model_path_)
{
}
Poco::URI CatBoostLibraryBridgeHelper::getPingURI() const
{
auto uri = createBaseURI();
uri.setPath(PING_HANDLER);
return uri;
}
Poco::URI CatBoostLibraryBridgeHelper::getMainURI() const
{
auto uri = createBaseURI();
uri.setPath(MAIN_HANDLER);
return uri;
}
Poco::URI CatBoostLibraryBridgeHelper::createRequestURI(const String & method) const
{
auto uri = getMainURI();
uri.addQueryParameter("version", std::to_string(LIBRARY_BRIDGE_PROTOCOL_VERSION));
uri.addQueryParameter("method", method);
return uri;
}
bool CatBoostLibraryBridgeHelper::bridgeHandShake()
{
String result;
try
{
ReadWriteBufferFromHTTP buf(getPingURI(), Poco::Net::HTTPRequest::HTTP_GET, {}, http_timeouts, credentials);
readString(result, buf);
}
catch (...)
{
tryLogCurrentException(log);
return false;
}
if (result != "1")
throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected message from library bridge: {}. Check that bridge and server have the same version.", result);
return true;
}
size_t CatBoostLibraryBridgeHelper::getTreeCount()
{
startBridgeSync();
ReadWriteBufferFromHTTP buf(
createRequestURI(CATBOOST_GETTREECOUNT_METHOD),
Poco::Net::HTTPRequest::HTTP_POST,
[this](std::ostream & os)
{
os << "library_path=" << escapeForFileName(library_path) << "&";
os << "model_path=" << escapeForFileName(model_path);
},
http_timeouts, credentials);
size_t res;
readIntBinary(res, buf);
return res;
}
ColumnPtr CatBoostLibraryBridgeHelper::evaluate(const ColumnsWithTypeAndName & columns)
{
startBridgeSync();
WriteBufferFromOwnString string_write_buf;
Block block(columns);
NativeWriter serializer(string_write_buf, /*client_revision*/ 0, block);
serializer.write(block);
ReadWriteBufferFromHTTP buf(
createRequestURI(CATBOOST_LIB_EVALUATE_METHOD),
Poco::Net::HTTPRequest::HTTP_POST,
[this, serialized = string_write_buf.str()](std::ostream & os)
{
os << "model_path=" << escapeForFileName(model_path) << "&";
os << "data=" << escapeForFileName(serialized);
},
http_timeouts, credentials);
NativeReader deserializer(buf, /*server_revision*/ 0);
Block block_read = deserializer.read();
return block_read.getColumns()[0];
}
}

View File

@ -0,0 +1,42 @@
#pragma once
#include <BridgeHelper/LibraryBridgeHelper.h>
#include <DataTypes/IDataType.h>
#include <IO/ReadWriteBufferFromHTTP.h>
#include <Interpreters/Context.h>
#include <Poco/URI.h>
namespace DB
{
class CatBoostLibraryBridgeHelper : public LibraryBridgeHelper
{
public:
static constexpr inline auto PING_HANDLER = "/catboost_ping";
static constexpr inline auto MAIN_HANDLER = "/catboost_request";
CatBoostLibraryBridgeHelper(ContextPtr context_, std::string_view library_path_, std::string_view model_path_);
size_t getTreeCount();
ColumnPtr evaluate(const ColumnsWithTypeAndName & columns);
protected:
Poco::URI getPingURI() const override;
Poco::URI getMainURI() const override;
bool bridgeHandShake() override;
private:
static constexpr inline auto CATBOOST_GETTREECOUNT_METHOD = "catboost_GetTreeCount";
static constexpr inline auto CATBOOST_LIB_EVALUATE_METHOD = "catboost_libEvaluate";
Poco::URI createRequestURI(const String & method) const;
const String library_path;
const String model_path;
};
}

View File

@ -12,8 +12,8 @@
namespace DB
{
/// Common base class for XDBC and Library bridge helpers.
/// Contains helper methods to check/start bridge sync.
/// Base class for server-side bridge helpers, e.g. xdbc-bridge and library-bridge.
/// Contains helper methods to check/start bridge sync
class IBridgeHelper: protected WithContext
{

View File

@ -176,10 +176,10 @@ static void tryLogCurrentExceptionImpl(Poco::Logger * logger, const std::string
void tryLogCurrentException(const char * log_name, const std::string & start_of_message)
{
/// Under high memory pressure, any new allocation will definitelly lead
/// to MEMORY_LIMIT_EXCEEDED exception.
/// Under high memory pressure, new allocations throw a
/// MEMORY_LIMIT_EXCEEDED exception.
///
/// And in this case the exception will not be logged, so let's block the
/// In this case the exception will not be logged, so let's block the
/// MemoryTracker until the exception will be logged.
LockMemoryExceptionInThread lock_memory_tracker(VariableContext::Global);
@ -189,8 +189,8 @@ void tryLogCurrentException(const char * log_name, const std::string & start_of_
void tryLogCurrentException(Poco::Logger * logger, const std::string & start_of_message)
{
/// Under high memory pressure, any new allocation will definitelly lead
/// to MEMORY_LIMIT_EXCEEDED exception.
/// Under high memory pressure, new allocations throw a
/// MEMORY_LIMIT_EXCEEDED exception.
///
/// And in this case the exception will not be logged, so let's block the
/// MemoryTracker until the exception will be logged.

View File

@ -0,0 +1,182 @@
#include <Functions/FunctionHelpers.h>
#include <Functions/FunctionFactory.h>
#include <BridgeHelper/CatBoostLibraryBridgeHelper.h>
#include <BridgeHelper/IBridgeHelper.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnTuple.h>
#include <Columns/ColumnsNumber.h>
#include <Common/assert_cast.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/IFunction.h>
#include <Interpreters/Context.h>
#include <Interpreters/Context_fwd.h>
namespace DB
{
namespace ErrorCodes
{
extern const int FILE_DOESNT_EXIST;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int TOO_FEW_ARGUMENTS_FOR_FUNCTION;
extern const int ILLEGAL_COLUMN;
}
/// Evaluate CatBoost model.
/// - Arguments: float features first, then categorical features.
/// - Result: Float64.
class FunctionCatBoostEvaluate final : public IFunction, WithContext
{
private:
mutable std::unique_ptr<CatBoostLibraryBridgeHelper> bridge_helper;
public:
static constexpr auto name = "catboostEvaluate";
static FunctionPtr create(ContextPtr context_) { return std::make_shared<FunctionCatBoostEvaluate>(context_); }
explicit FunctionCatBoostEvaluate(ContextPtr context_) : WithContext(context_) {}
String getName() const override { return name; }
bool isVariadic() const override { return true; }
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }
bool isDeterministic() const override { return false; }
bool useDefaultImplementationForNulls() const override { return false; }
size_t getNumberOfArguments() const override { return 0; }
void initBridge(const ColumnConst * name_col) const
{
String library_path = getContext()->getConfigRef().getString("catboost_lib_path");
if (!std::filesystem::exists(library_path))
throw Exception(ErrorCodes::FILE_DOESNT_EXIST, "Can't load library {}: file doesn't exist", library_path);
String model_path = name_col->getValue<String>();
if (!std::filesystem::exists(model_path))
throw Exception(ErrorCodes::FILE_DOESNT_EXIST, "Can't load model {}: file doesn't exist", model_path);
bridge_helper = std::make_unique<CatBoostLibraryBridgeHelper>(getContext(), library_path, model_path);
}
DataTypePtr getReturnTypeFromLibraryBridge() const
{
size_t tree_count = bridge_helper->getTreeCount();
auto type = std::make_shared<DataTypeFloat64>();
if (tree_count == 1)
return type;
DataTypes types(tree_count, type);
return std::make_shared<DataTypeTuple>(types);
}
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
if (arguments.size() < 2)
throw Exception(ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION, "Function {} expects at least 2 arguments", getName());
if (!isString(arguments[0].type))
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal type {} of first argument of function {}, expected a string.", arguments[0].type->getName(), getName());
const auto * name_col = checkAndGetColumnConst<ColumnString>(arguments[0].column.get());
if (!name_col)
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "First argument of function {} must be a constant string", getName());
initBridge(name_col);
auto type = getReturnTypeFromLibraryBridge();
bool has_nullable = false;
for (size_t i = 1; i < arguments.size(); ++i)
has_nullable = has_nullable || arguments[i].type->isNullable();
if (has_nullable)
{
if (const auto * tuple = typeid_cast<const DataTypeTuple *>(type.get()))
{
auto elements = tuple->getElements();
for (auto & element : elements)
element = makeNullable(element);
type = std::make_shared<DataTypeTuple>(elements);
}
else
type = makeNullable(type);
}
return type;
}
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t) const override
{
const auto * name_col = checkAndGetColumnConst<ColumnString>(arguments[0].column.get());
if (!name_col)
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "First argument of function {} must be a constant string", getName());
ColumnRawPtrs column_ptrs;
Columns materialized_columns;
ColumnPtr null_map;
ColumnsWithTypeAndName feature_arguments(arguments.begin() + 1, arguments.end());
for (auto & arg : feature_arguments)
{
if (auto full_column = arg.column->convertToFullColumnIfConst())
{
materialized_columns.push_back(full_column);
arg.column = full_column;
}
if (const auto * col_nullable = checkAndGetColumn<ColumnNullable>(&*arg.column))
{
if (!null_map)
null_map = col_nullable->getNullMapColumnPtr();
else
{
auto mut_null_map = IColumn::mutate(std::move(null_map));
NullMap & result_null_map = assert_cast<ColumnUInt8 &>(*mut_null_map).getData();
const NullMap & src_null_map = col_nullable->getNullMapColumn().getData();
for (size_t i = 0, size = result_null_map.size(); i < size; ++i)
if (src_null_map[i])
result_null_map[i] = 1;
null_map = std::move(mut_null_map);
}
arg.column = col_nullable->getNestedColumn().getPtr();
arg.type = static_cast<const DataTypeNullable &>(*arg.type).getNestedType();
}
}
auto res = bridge_helper->evaluate(feature_arguments);
if (null_map)
{
if (const auto * tuple = typeid_cast<const ColumnTuple *>(res.get()))
{
auto nested = tuple->getColumns();
for (auto & col : nested)
col = ColumnNullable::create(col, null_map);
res = ColumnTuple::create(nested);
}
else
res = ColumnNullable::create(res, null_map);
}
return res;
}
};
REGISTER_FUNCTION(CatBoostEvaluate)
{
factory.registerFunction<FunctionCatBoostEvaluate>();
}
}

View File

@ -0,0 +1,3 @@
<clickhouse>
<catboost_lib_path>/etc/clickhouse-server/model/libcatboostmodel.so</catboost_lib_path>
</clickhouse>

View File

@ -0,0 +1,318 @@
import os
import sys
import time
import pytest
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
from helpers.cluster import ClickHouseCluster
cluster = ClickHouseCluster(__file__)
instance = cluster.add_instance(
"instance", stay_alive=True, main_configs=["config/models_config.xml"]
)
@pytest.fixture(scope="module")
def ch_cluster():
try:
cluster.start()
os.system(
"docker cp {local} {cont_id}:{dist}".format(
local=os.path.join(SCRIPT_DIR, "model/."),
cont_id=instance.docker_id,
dist="/etc/clickhouse-server/model",
)
)
instance.restart_clickhouse()
yield cluster
finally:
cluster.shutdown()
# ---------------------------------------------------------------------------
# simple_model.bin has 2 float features and 9 categorical features
def testConstantFeatures(ch_cluster):
if instance.is_built_with_memory_sanitizer():
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
result = instance.query(
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
)
expected = "-1.930268705869267\n"
assert result == expected
def testNonConstantFeatures(ch_cluster):
if instance.is_built_with_memory_sanitizer():
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
instance.query("DROP TABLE IF EXISTS T;")
instance.query(
"CREATE TABLE T(ID UInt32, F1 Float32, F2 Float32, F3 UInt32, F4 UInt32, F5 UInt32, F6 UInt32, F7 UInt32, F8 UInt32, F9 Float32, F10 Float32, F11 Float32) ENGINE MergeTree ORDER BY ID;"
)
instance.query("INSERT INTO T VALUES(0, 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);")
result = instance.query(
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', F1, F2, F3, F4, F5, F6, F7, F8, F9, F10, F11) from T;"
)
expected = "-1.930268705869267\n"
assert result == expected
instance.query("DROP TABLE IF EXISTS T;")
def testModelPathIsNotAConstString(ch_cluster):
if instance.is_built_with_memory_sanitizer():
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
err = instance.query_and_get_error(
"select catboostEvaluate(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
)
assert (
"Illegal type UInt8 of first argument of function catboostEvaluate, expected a string"
in err
)
instance.query("DROP TABLE IF EXISTS T;")
instance.query("CREATE TABLE T(ID UInt32, A String) ENGINE MergeTree ORDER BY ID")
instance.query("INSERT INTO T VALUES(0, 'test');")
err = instance.query_and_get_error(
"select catboostEvaluate(A, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11) FROM T;"
)
assert (
"First argument of function catboostEvaluate must be a constant string" in err
)
instance.query("DROP TABLE IF EXISTS T;")
def testWrongNumberOfFeatureArguments(ch_cluster):
if instance.is_built_with_memory_sanitizer():
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
err = instance.query_and_get_error(
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin');"
)
assert "Function catboostEvaluate expects at least 2 arguments" in err
err = instance.query_and_get_error(
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1, 2);"
)
assert (
"Number of columns is different with number of features: columns size 2 float features size 2 + cat features size 9"
in err
)
def testFloatFeatureMustBeNumeric(ch_cluster):
if instance.is_built_with_memory_sanitizer():
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
err = instance.query_and_get_error(
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 'a', 3, 4, 5, 6, 7, 8, 9, 10, 11);"
)
assert "Column 1 should be numeric to make float feature" in err
def testCategoricalFeatureMustBeNumericOrString(ch_cluster):
if instance.is_built_with_memory_sanitizer():
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
err = instance.query_and_get_error(
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, tuple(8), 9, 10, 11);"
)
assert "Column 7 should be numeric or string" in err
def testOnLowCardinalityFeatures(ch_cluster):
if instance.is_built_with_memory_sanitizer():
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
# same but on domain-compressed data
result = instance.query(
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', toLowCardinality(1.0), toLowCardinality(2.0), toLowCardinality(3), toLowCardinality(4), toLowCardinality(5), toLowCardinality(6), toLowCardinality(7), toLowCardinality(8), toLowCardinality(9), toLowCardinality(10), toLowCardinality(11));"
)
expected = "-1.930268705869267\n"
assert result == expected
def testOnNullableFeatures(ch_cluster):
if instance.is_built_with_memory_sanitizer():
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
result = instance.query(
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', toNullable(1.0), toNullable(2.0), toNullable(3), toNullable(4), toNullable(5), toNullable(6), toNullable(7), toNullable(8), toNullable(9), toNullable(10), toNullable(11));"
)
expected = "-1.930268705869267\n"
assert result == expected
# Actual NULLs are disallowed
err = instance.query_and_get_error(
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', toNullable(NULL), toNullable(NULL), toNullable(NULL), toNullable(NULL), toNullable(NULL), toNullable(NULL), toNullable(NULL), toNullable(NULL), toNullable(NULL), toNullable(NULL), toNullable(NULL));"
)
assert "Column 0 should be numeric to make float feature" in err
def testInvalidLibraryPath(ch_cluster):
if instance.is_built_with_memory_sanitizer():
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
# temporarily move library elsewhere
instance.exec_in_container(
[
"bash",
"-c",
"mv /etc/clickhouse-server/model/libcatboostmodel.so /etc/clickhouse-server/model/nonexistant.so",
]
)
err = instance.query_and_get_error(
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
)
assert (
"Can't load library /etc/clickhouse-server/model/libcatboostmodel.so: file doesn't exist"
in err
)
# restore
instance.exec_in_container(
[
"bash",
"-c",
"mv /etc/clickhouse-server/model/nonexistant.so /etc/clickhouse-server/model/libcatboostmodel.so",
]
)
def testInvalidModelPath(ch_cluster):
if instance.is_built_with_memory_sanitizer():
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
err = instance.query_and_get_error(
"select catboostEvaluate('', 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
)
assert "Can't load model : file doesn't exist" in err
err = instance.query_and_get_error(
"select catboostEvaluate('model_non_existant.bin', 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
)
assert "Can't load model model_non_existant.bin: file doesn't exist" in err
def testRecoveryAfterCrash(ch_cluster):
if instance.is_built_with_memory_sanitizer():
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
result = instance.query(
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
)
expected = "-1.930268705869267\n"
assert result == expected
instance.exec_in_container(
["bash", "-c", "kill -9 `pidof clickhouse-library-bridge`"], user="root"
)
result = instance.query(
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
)
assert result == expected
# ---------------------------------------------------------------------------
# amazon_model.bin has 0 float features and 9 categorical features
def testAmazonModelSingleRow(ch_cluster):
if instance.is_built_with_memory_sanitizer():
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
result = instance.query(
"select catboostEvaluate('/etc/clickhouse-server/model/amazon_model.bin', 1, 2, 3, 4, 5, 6, 7, 8, 9);"
)
expected = "0.7774665009089274\n"
assert result == expected
def testAmazonModelManyRows(ch_cluster):
if instance.is_built_with_memory_sanitizer():
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
result = instance.query("drop table if exists amazon")
result = instance.query(
"create table amazon ( DATE Date materialized today(), ACTION UInt8, RESOURCE UInt32, MGR_ID UInt32, ROLE_ROLLUP_1 UInt32, ROLE_ROLLUP_2 UInt32, ROLE_DEPTNAME UInt32, ROLE_TITLE UInt32, ROLE_FAMILY_DESC UInt32, ROLE_FAMILY UInt32, ROLE_CODE UInt32) engine = MergeTree order by DATE"
)
result = instance.query(
"insert into amazon select number % 256, number, number, number, number, number, number, number, number, number from numbers(7500)"
)
# First compute prediction, then as a very crude way to fingerprint and compare the result: sum and floor
# (the focus is to test that the exchange of large result sets between the server and the bridge works)
result = instance.query(
"SELECT floor(sum(catboostEvaluate('/etc/clickhouse-server/model/amazon_model.bin', RESOURCE, MGR_ID, ROLE_ROLLUP_1, ROLE_ROLLUP_2, ROLE_DEPTNAME, ROLE_TITLE, ROLE_FAMILY_DESC, ROLE_FAMILY, ROLE_CODE))) FROM amazon"
)
expected = "5834\n"
assert result == expected
result = instance.query("drop table if exists amazon")
def testModelUpdate(ch_cluster):
if instance.is_built_with_memory_sanitizer():
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
query = "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
result = instance.query(query)
expected = "-1.930268705869267\n"
assert result == expected
# simulate an update of the model: temporarily move the amazon model in place of the simple model
instance.exec_in_container(
[
"bash",
"-c",
"mv /etc/clickhouse-server/model/simple_model.bin /etc/clickhouse-server/model/simple_model.bin.bak",
]
)
instance.exec_in_container(
[
"bash",
"-c",
"mv /etc/clickhouse-server/model/amazon_model.bin /etc/clickhouse-server/model/simple_model.bin",
]
)
# since the amazon model has a different number of features than the simple model, we should get an error
err = instance.query_and_get_error(query)
assert (
"Number of columns is different with number of features: columns size 11 float features size 0 + cat features size 9"
in err
)
# restore
instance.exec_in_container(
[
"bash",
"-c",
"mv /etc/clickhouse-server/model/simple_model.bin /etc/clickhouse-server/model/amazon_model.bin",
]
)
instance.exec_in_container(
[
"bash",
"-c",
"mv /etc/clickhouse-server/model/simple_model.bin.bak /etc/clickhouse-server/model/simple_model.bin",
]
)