added CatBoostModel [#CLICKHOUSE-3305]

This commit is contained in:
Nikolai Kochetov 2017-10-06 21:05:30 +03:00
parent 93e1401b35
commit e817de7e21
2 changed files with 126 additions and 3 deletions

View File

@ -1,14 +1,122 @@
#include <Dictionaries/CatBoostModel.h>
#include <boost/dll/import.hpp>
#include <mutex>
namespace DB
{
CatBoostModel::CatBoostModel(const Poco::Util::AbstractConfiguration & config, const std::string & config_prefix)
namespace
{
struct CatBoostWrapperApi
{
typedef void ModelCalcerHandle;
ModelCalcerHandle * (* ModelCalcerCreate)();
void (* ModelCalcerDelete)(ModelCalcerHandle * calcer);
const char * (* GetErrorString)();
bool (* LoadFullModelFromFile)(ModelCalcerHandle * calcer, const char * filename);
bool (* CalcModelPredictionFlat)(ModelCalcerHandle * calcer, size_t docCount,
const float ** floatFeatures, size_t floatFeaturesSize,
double * result, size_t resultSize);
bool (* CalcModelPrediction)(ModelCalcerHandle * calcer, size_t docCount,
const float ** floatFeatures, size_t floatFeaturesSize,
const char *** catFeatures, size_t catFeaturesSize,
double * result, size_t resultSize);
bool (* CalcModelPredictionWithHashedCatFeatures)(ModelCalcerHandle * calcer, size_t docCount,
const float ** floatFeatures, size_t floatFeaturesSize,
const int ** catFeatures, size_t catFeaturesSize,
double * result, size_t resultSize);
int (* GetStringCatFeatureHash)(const char * data, size_t size);
int (* GetIntegerCatFeatureHash)(long long val);
};
class CatBoostWrapperHolder : public CatBoostWrapperApiProvider
{
public:
CatBoostWrapperHolder(const std::string & lib_path) : lib(lib_path), lib_path(lib_path) { initApi(); }
const CatBoostWrapperApi & getApi() const override { return api; }
const std::string & getCurrentPath() const { return lib_path; }
private:
CatBoostWrapperApi api;
std::string lib_path;
boost::dll::shared_library lib;
void initApi();
template <typename T>
void load(T& func, const std::string & name)
{
using Type = std::remove_pointer<T>::type;
func = lib.get<Type>(name);
}
};
void CatBoostWrapperHolder::initApi()
{
load(api.ModelCalcerCreate, "ModelCalcerCreate");
load(api.ModelCalcerDelete, "ModelCalcerDelete");
load(api.GetErrorString, "GetErrorString");
load(api.LoadFullModelFromFile, "LoadFullModelFromFile");
load(api.CalcModelPredictionFlat, "CalcModelPredictionFlat");
load(api.CalcModelPrediction, "CalcModelPrediction");
load(api.CalcModelPredictionWithHashedCatFeatures, "CalcModelPredictionWithHashedCatFeatures");
load(api.GetStringCatFeatureHash, "GetStringCatFeatureHash");
load(api.GetIntegerCatFeatureHash, "GetIntegerCatFeatureHash");
}
std::shared_ptr<CatBoostWrapperHolder> getCatBoostWrapperHolder(const std::string & lib_path)
{
static std::weak_ptr<CatBoostWrapperHolder> ptr;
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
auto result = ptr.lock();
if (!result || result->getCurrentPath() != lib_path)
{
result = std::make_shared<CatBoostWrapperHolder>(lib_path);
/// This assignment is not atomic, which prevents from creating lock only inside 'if'.
ptr = result;
}
return result;
}
}
/// Creates a model whose lifetime settings are read from `config` under `config_prefix`.
/// `name` and `model_path` are left empty, and `api` is explicitly set to nullptr so that the raw
/// pointer never holds an indeterminate value.
CatBoostModel::CatBoostModel(const Poco::Util::AbstractConfiguration & config,
                             const std::string & config_prefix, const std::string & lib_path)
    : lifetime(config, config_prefix), api(nullptr)
{
    /// NOTE(review): lib_path is not used here, so this constructor does not load the CatBoost
    /// library or fill api_provider - confirm whether that is intended.
}
/// Creates a model with the given name, model file path and lifetime, and obtains the CatBoost
/// function table for the library at `lib_path`.
/// A failure to obtain the library does not propagate from the constructor: the exception is
/// stored in `creation_exception`, and `api` stays nullptr instead of being left indeterminate.
CatBoostModel::CatBoostModel(const std::string & name, const std::string & model_path, const std::string & lib_path,
                             const ExternalLoadableLifetime & lifetime)
    : name(name), model_path(model_path), lifetime(lifetime), api(nullptr)
{
    try
    {
        /// api_provider keeps the library holder alive; api points into that holder.
        api_provider = getCatBoostWrapperHolder(lib_path);
        api = &api_provider->getApi();
    }
    catch (...)
    {
        creation_exception = std::current_exception();
    }
}
const ExternalLoadableLifetime & CatBoostModel::getLifetime() const
{
return lifetime;

View File

@ -7,10 +7,20 @@
namespace DB
{
/// Forward declaration only: code that includes this header sees an incomplete type
/// and can use it solely through pointers and references.
struct CatBoostWrapperApi;

/// Abstract source of a CatBoostWrapperApi.
/// Implementations decide where the API object lives; callers only receive a const reference to it.
class CatBoostWrapperApiProvider
{
public:
    /// Virtual so that implementations can be destroyed through a pointer to this interface.
    virtual ~CatBoostWrapperApiProvider() = default;

    /// Returns the API object owned by the provider.
    virtual const CatBoostWrapperApi & getApi() const = 0;
};
class CatBoostModel : public IExternalLoadable
{
public:
CatBoostModel(const Poco::Util::AbstractConfiguration & config, const std::string & config_prefix);
CatBoostModel(const Poco::Util::AbstractConfiguration & config, const std::string & config_prefix,
const std::string & lib_path);
const ExternalLoadableLifetime & getLifetime() const override;
@ -30,10 +40,15 @@ public:
void apply(const Columns & floatColumns, const Columns & catColumns, ColumnFloat64 & result);
private:
ExternalLoadableLifetime lifetime;
std::string name;
std::string model_path;
ExternalLoadableLifetime lifetime;
std::exception_ptr creation_exception;
std::shared_ptr<CatBoostWrapperApiProvider> api_provider;
const CatBoostWrapperApi * api;
CatBoostModel(const std::string & name, const std::string & model_path,
const std::string & lib_path, const ExternalLoadableLifetime & lifetime);
};
}