From e817de7e21c6806564ee36bb0c227bacfeb304b0 Mon Sep 17 00:00:00 2001
From: Nikolai Kochetov
Date: Fri, 6 Oct 2017 21:05:30 +0300
Subject: [PATCH] added CatBoostModel [#CLICKHOUSE-3305]

---
 dbms/src/Dictionaries/CatBoostModel.cpp | 126 +++++++++++++++++++++++-
 dbms/src/Dictionaries/CatBoostModel.h   |  23 ++++-
 2 files changed, 146 insertions(+), 3 deletions(-)

diff --git a/dbms/src/Dictionaries/CatBoostModel.cpp b/dbms/src/Dictionaries/CatBoostModel.cpp
index 4f75fc792a2..deb2a0e7af8 100644
--- a/dbms/src/Dictionaries/CatBoostModel.cpp
+++ b/dbms/src/Dictionaries/CatBoostModel.cpp
@@ -1,14 +1,138 @@
 #include <Dictionaries/CatBoostModel.h>
+#include <boost/dll/shared_library.hpp>
+#include <memory>
+#include <mutex>
+#include <type_traits>
 
 
 namespace DB
 {
 
-CatBoostModel::CatBoostModel(const Poco::Util::AbstractConfiguration & config, const std::string & config_prefix)
+/// Table of function pointers resolved from the CatBoost wrapper shared library.
+/// It is defined directly in namespace DB (not in the anonymous namespace below) so that it
+/// completes the forward declaration from CatBoostModel.h instead of declaring an unrelated type.
+struct CatBoostWrapperApi
+{
+    using ModelCalcerHandle = void;
+
+    ModelCalcerHandle * (* ModelCalcerCreate)();
+
+    void (* ModelCalcerDelete)(ModelCalcerHandle * calcer);
+
+    const char * (* GetErrorString)();
+
+    bool (* LoadFullModelFromFile)(ModelCalcerHandle * calcer, const char * filename);
+
+    bool (* CalcModelPredictionFlat)(ModelCalcerHandle * calcer, size_t docCount,
+                                     const float ** floatFeatures, size_t floatFeaturesSize,
+                                     double * result, size_t resultSize);
+
+    bool (* CalcModelPrediction)(ModelCalcerHandle * calcer, size_t docCount,
+                                 const float ** floatFeatures, size_t floatFeaturesSize,
+                                 const char *** catFeatures, size_t catFeaturesSize,
+                                 double * result, size_t resultSize);
+
+    bool (* CalcModelPredictionWithHashedCatFeatures)(ModelCalcerHandle * calcer, size_t docCount,
+                                                      const float ** floatFeatures, size_t floatFeaturesSize,
+                                                      const int ** catFeatures, size_t catFeaturesSize,
+                                                      double * result, size_t resultSize);
+
+    int (* GetStringCatFeatureHash)(const char * data, size_t size);
+
+    int (* GetIntegerCatFeatureHash)(long long val);
+};
+
+namespace
+{
+
+/// Loads the wrapper library found at lib_path and resolves every CatBoostWrapperApi entry from it.
+/// The library handle is a member, so the library stays loaded for as long as the holder exists.
+class CatBoostWrapperHolder : public CatBoostWrapperApiProvider
+{
+public:
+    /// Initializers are listed in member declaration order (lib_path first, then lib).
+    explicit CatBoostWrapperHolder(const std::string & lib_path) : lib_path(lib_path), lib(lib_path) { initApi(); }
+
+    const CatBoostWrapperApi & getApi() const override { return api; }
+    const std::string & getCurrentPath() const { return lib_path; }
+
+private:
+    CatBoostWrapperApi api;
+    std::string lib_path;
+    boost::dll::shared_library lib;
+
+    void initApi();
+
+    /// Looks up the exported symbol `name` in the library and stores its address in `func`.
+    template <typename T>
+    void load(T & func, const std::string & name)
+    {
+        using Type = typename std::remove_pointer<T>::type;
+        func = lib.get<Type>(name);
+    }
+};
+
+void CatBoostWrapperHolder::initApi()
+{
+    load(api.ModelCalcerCreate, "ModelCalcerCreate");
+    load(api.ModelCalcerDelete, "ModelCalcerDelete");
+    load(api.GetErrorString, "GetErrorString");
+    load(api.LoadFullModelFromFile, "LoadFullModelFromFile");
+    load(api.CalcModelPredictionFlat, "CalcModelPredictionFlat");
+    load(api.CalcModelPrediction, "CalcModelPrediction");
+    load(api.CalcModelPredictionWithHashedCatFeatures, "CalcModelPredictionWithHashedCatFeatures");
+    load(api.GetStringCatFeatureHash, "GetStringCatFeatureHash");
+    load(api.GetIntegerCatFeatureHash, "GetIntegerCatFeatureHash");
+}
+
+/// Returns a holder for lib_path, reusing the cached one while it is still alive and was
+/// loaded from the same path; otherwise a new holder is created and becomes the cached one.
+/// The cache is a weak_ptr guarded by a mutex, so it never keeps a holder alive by itself.
+std::shared_ptr<CatBoostWrapperHolder> getCatBoostWrapperHolder(const std::string & lib_path)
+{
+    static std::weak_ptr<CatBoostWrapperHolder> ptr;
+    static std::mutex mutex;
+
+    std::lock_guard<std::mutex> lock(mutex);
+    auto result = ptr.lock();
+
+    if (!result || result->getCurrentPath() != lib_path)
+    {
+        result = std::make_shared<CatBoostWrapperHolder>(lib_path);
+        /// This assignment is not atomic, which prevents from creating lock only inside 'if'.
+        ptr = result;
+    }
+
+    return result;
+}
+
+}
+
+/// NOTE: lib_path is accepted but not used by this constructor; only `lifetime` is initialized
+/// from the configuration, so api_provider and api stay null.
+CatBoostModel::CatBoostModel(const Poco::Util::AbstractConfiguration & config,
+    const std::string & config_prefix, const std::string & lib_path)
     : lifetime(config, config_prefix)
 {
 }
 
+/// Stores the model parameters and tries to obtain the wrapper API for lib_path.
+/// Any exception thrown while doing so is kept in creation_exception instead of propagating.
+CatBoostModel::CatBoostModel(const std::string & name, const std::string & model_path, const std::string & lib_path,
+    const ExternalLoadableLifetime & lifetime)
+    : name(name), model_path(model_path), lifetime(lifetime)
+{
+    try
+    {
+        api_provider = getCatBoostWrapperHolder(lib_path);
+        api = &api_provider->getApi();
+    }
+    catch (...)
+    {
+        creation_exception = std::current_exception();
+    }
+}
+
 const ExternalLoadableLifetime & CatBoostModel::getLifetime() const
 {
     return lifetime;
diff --git a/dbms/src/Dictionaries/CatBoostModel.h b/dbms/src/Dictionaries/CatBoostModel.h
index 3163532b009..adacffab05b 100644
--- a/dbms/src/Dictionaries/CatBoostModel.h
+++ b/dbms/src/Dictionaries/CatBoostModel.h
@@ -7,10 +7,22 @@
 namespace DB
 {
 
+struct CatBoostWrapperApi;
+
+/// Interface of an object that owns a resolved CatBoostWrapperApi table.
+class CatBoostWrapperApiProvider
+{
+public:
+    virtual ~CatBoostWrapperApiProvider() = default;
+    virtual const CatBoostWrapperApi & getApi() const = 0;
+};
+
+
 class CatBoostModel : public IExternalLoadable
 {
 public:
-    CatBoostModel(const Poco::Util::AbstractConfiguration & config, const std::string & config_prefix);
+    CatBoostModel(const Poco::Util::AbstractConfiguration & config, const std::string & config_prefix,
+                  const std::string & lib_path);
 
     const ExternalLoadableLifetime & getLifetime() const override;
 
@@ -30,10 +42,17 @@ public:
     void apply(const Columns & floatColumns, const Columns & catColumns, ColumnFloat64 & result);
 
 private:
-    ExternalLoadableLifetime lifetime;
     std::string name;
+    std::string model_path;
+    ExternalLoadableLifetime lifetime;
     std::exception_ptr creation_exception;
+    /// Keeps the API provider alive while `api` points into it.
+    std::shared_ptr<CatBoostWrapperApiProvider> api_provider;
+    /// Null until the wrapper API has been obtained successfully.
+    const CatBoostWrapperApi * api = nullptr;
 
+    CatBoostModel(const std::string & name, const std::string & model_path,
+                  const std::string & lib_path, const ExternalLoadableLifetime & lifetime);
 };
 
 }