mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-10 01:25:21 +00:00
60f9f6855d
This commit moves the catboost model evaluation out of the server process into the library-bridge binary. This serves two goals: On the one hand, crashes / memory corruptions of the catboost library no longer affect the server. On the other hand, we can forbid loading dynamic libraries in the server (catboost was the last consumer of this functionality), thus improving security. SQL syntax: SELECT catboostEvaluate('/path/to/model.bin', FEAT_1, ..., FEAT_N) > 0 AS prediction, ACTION AS target FROM amazon_train LIMIT 10 Required configuration: <catboost_lib_path>/path/to/libcatboostmodel.so</catboost_lib_path> *** Implementation Details *** The internal protocol between the server and the library-bridge is simple: - HTTP GET on path "/extdict_ping": A ping, used during the handshake to check if the library-bridge runs. - HTTP POST on path "extdict_request" (1) Send a "catboost_GetTreeCount" request from the server to the bridge, containing a library path (e.g. /home/user/libcatboost.so) and a model path (e.g. /home/user/model.bin). First, this unloads the catboost library handler associated to the model path (if it was loaded), then loads the catboost library handler associated to the model path, then executes GetTreeCount() on the library handler and finally sends the result back to the server. Step (1) is called once by the server from FunctionCatBoostEvaluate::getReturnTypeImpl(). The library path handler is unloaded in the beginning because it contains state which may no longer be valid if the user runs catboost("/path/to/model.bin", ...) more than once and if "model.bin" was updated in between. (2) Send "catboost_Evaluate" from the server to the bridge, containing the model path and the features to run the inference on. Step (2) is called multiple times (once per chunk) by the server from function FunctionCatBoostEvaluate::executeImpl(). The library handler for the given model path is expected to be already loaded by Step (1). Fixes #27870
72 lines
2.4 KiB
C++
72 lines
2.4 KiB
C++
#pragma once
|
|
|
|
#include "CatBoostLibraryAPI.h"
|
|
|
|
#include <Columns/ColumnFixedString.h>
|
|
#include <Columns/ColumnString.h>
|
|
#include <Columns/ColumnVector.h>
|
|
#include <Columns/ColumnsNumber.h>
|
|
#include <Columns/IColumn.h>
|
|
#include <Common/SharedLibrary.h>
|
|
#include <base/defines.h>
|
|
|
|
#include <mutex>
|
|
|
|
namespace DB
|
|
{
|
|
|
|
/// Abstracts access to the CatBoost shared library.
|
|
class CatBoostLibraryHandler
|
|
{
|
|
public:
|
|
/// Holds pointers to CatBoost library functions
|
|
struct APIHolder
|
|
{
|
|
explicit APIHolder(SharedLibrary & lib);
|
|
|
|
// NOLINTBEGIN(readability-identifier-naming)
|
|
CatBoostLibraryAPI::ModelCalcerCreateFunc ModelCalcerCreate;
|
|
CatBoostLibraryAPI::ModelCalcerDeleteFunc ModelCalcerDelete;
|
|
CatBoostLibraryAPI::GetErrorStringFunc GetErrorString;
|
|
CatBoostLibraryAPI::LoadFullModelFromFileFunc LoadFullModelFromFile;
|
|
CatBoostLibraryAPI::CalcModelPredictionFlatFunc CalcModelPredictionFlat;
|
|
CatBoostLibraryAPI::CalcModelPredictionFunc CalcModelPrediction;
|
|
CatBoostLibraryAPI::CalcModelPredictionWithHashedCatFeaturesFunc CalcModelPredictionWithHashedCatFeatures;
|
|
CatBoostLibraryAPI::GetStringCatFeatureHashFunc GetStringCatFeatureHash;
|
|
CatBoostLibraryAPI::GetIntegerCatFeatureHashFunc GetIntegerCatFeatureHash;
|
|
CatBoostLibraryAPI::GetFloatFeaturesCountFunc GetFloatFeaturesCount;
|
|
CatBoostLibraryAPI::GetCatFeaturesCountFunc GetCatFeaturesCount;
|
|
CatBoostLibraryAPI::GetTreeCountFunc GetTreeCount;
|
|
CatBoostLibraryAPI::GetDimensionsCountFunc GetDimensionsCount;
|
|
// NOLINTEND(readability-identifier-naming)
|
|
};
|
|
|
|
CatBoostLibraryHandler(
|
|
const std::string & library_path,
|
|
const std::string & model_path);
|
|
|
|
~CatBoostLibraryHandler();
|
|
|
|
size_t getTreeCount() const;
|
|
|
|
ColumnPtr evaluate(const ColumnRawPtrs & columns) const;
|
|
|
|
private:
|
|
const SharedLibraryPtr library;
|
|
const APIHolder api;
|
|
|
|
mutable std::mutex mutex;
|
|
|
|
CatBoostLibraryAPI::ModelCalcerHandle * model_calcer_handle TSA_GUARDED_BY(mutex) TSA_PT_GUARDED_BY(mutex);
|
|
|
|
size_t float_features_count TSA_GUARDED_BY(mutex);
|
|
size_t cat_features_count TSA_GUARDED_BY(mutex);
|
|
size_t tree_count TSA_GUARDED_BY(mutex);
|
|
|
|
ColumnFloat64::MutablePtr evalImpl(const ColumnRawPtrs & columns, bool cat_features_are_strings) const TSA_REQUIRES(mutex);
|
|
};
|
|
|
|
/// Shared-ownership pointer to a CatBoostLibraryHandler.
using CatBoostLibraryHandlerPtr = std::shared_ptr<CatBoostLibraryHandler>;
|
|
|
|
}
|