mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-23 08:02:02 +00:00
added CatBoostModel [#CLICKHOUSE-3305]
This commit is contained in:
parent
93e1401b35
commit
e817de7e21
@ -1,14 +1,122 @@
|
||||
#include <Dictionaries/CatBoostModel.h>
|
||||
#include <boost/dll/import.hpp>
|
||||
#include <mutex>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
CatBoostModel::CatBoostModel(const Poco::Util::AbstractConfiguration & config, const std::string & config_prefix)
|
||||
namespace
|
||||
{
|
||||
|
||||
struct CatBoostWrapperApi
|
||||
{
|
||||
typedef void ModelCalcerHandle;
|
||||
|
||||
ModelCalcerHandle * (* ModelCalcerCreate)();
|
||||
|
||||
void (* ModelCalcerDelete)(ModelCalcerHandle * calcer);
|
||||
|
||||
const char * (* GetErrorString)();
|
||||
|
||||
bool (* LoadFullModelFromFile)(ModelCalcerHandle * calcer, const char * filename);
|
||||
|
||||
bool (* CalcModelPredictionFlat)(ModelCalcerHandle * calcer, size_t docCount,
|
||||
const float ** floatFeatures, size_t floatFeaturesSize,
|
||||
double * result, size_t resultSize);
|
||||
|
||||
bool (* CalcModelPrediction)(ModelCalcerHandle * calcer, size_t docCount,
|
||||
const float ** floatFeatures, size_t floatFeaturesSize,
|
||||
const char *** catFeatures, size_t catFeaturesSize,
|
||||
double * result, size_t resultSize);
|
||||
|
||||
bool (* CalcModelPredictionWithHashedCatFeatures)(ModelCalcerHandle * calcer, size_t docCount,
|
||||
const float ** floatFeatures, size_t floatFeaturesSize,
|
||||
const int ** catFeatures, size_t catFeaturesSize,
|
||||
double * result, size_t resultSize);
|
||||
|
||||
int (* GetStringCatFeatureHash)(const char * data, size_t size);
|
||||
|
||||
int (* GetIntegerCatFeatureHash)(long long val);
|
||||
};
|
||||
|
||||
class CatBoostWrapperHolder : public CatBoostWrapperApiProvider
|
||||
{
|
||||
public:
|
||||
CatBoostWrapperHolder(const std::string & lib_path) : lib(lib_path), lib_path(lib_path) { initApi(); }
|
||||
|
||||
const CatBoostWrapperApi & getApi() const override { return api; }
|
||||
const std::string & getCurrentPath() const { return lib_path; }
|
||||
|
||||
private:
|
||||
CatBoostWrapperApi api;
|
||||
std::string lib_path;
|
||||
boost::dll::shared_library lib;
|
||||
|
||||
void initApi();
|
||||
|
||||
template <typename T>
|
||||
void load(T& func, const std::string & name)
|
||||
{
|
||||
using Type = std::remove_pointer<T>::type;
|
||||
func = lib.get<Type>(name);
|
||||
}
|
||||
};
|
||||
|
||||
void CatBoostWrapperHolder::initApi()
|
||||
{
|
||||
load(api.ModelCalcerCreate, "ModelCalcerCreate");
|
||||
load(api.ModelCalcerDelete, "ModelCalcerDelete");
|
||||
load(api.GetErrorString, "GetErrorString");
|
||||
load(api.LoadFullModelFromFile, "LoadFullModelFromFile");
|
||||
load(api.CalcModelPredictionFlat, "CalcModelPredictionFlat");
|
||||
load(api.CalcModelPrediction, "CalcModelPrediction");
|
||||
load(api.CalcModelPredictionWithHashedCatFeatures, "CalcModelPredictionWithHashedCatFeatures");
|
||||
load(api.GetStringCatFeatureHash, "GetStringCatFeatureHash");
|
||||
load(api.GetIntegerCatFeatureHash, "GetIntegerCatFeatureHash");
|
||||
}
|
||||
|
||||
std::shared_ptr<CatBoostWrapperHolder> getCatBoostWrapperHolder(const std::string & lib_path)
|
||||
{
|
||||
static std::weak_ptr<CatBoostWrapperHolder> ptr;
|
||||
static std::mutex mutex;
|
||||
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
auto result = ptr.lock();
|
||||
|
||||
if (!result || result->getCurrentPath() != lib_path)
|
||||
{
|
||||
result = std::make_shared<CatBoostWrapperHolder>(lib_path);
|
||||
/// This assignment is not atomic, which prevents from creating lock only inside 'if'.
|
||||
ptr = result;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
CatBoostModel::CatBoostModel(const Poco::Util::AbstractConfiguration & config,
|
||||
const std::string & config_prefix, const std::string & lib_path)
|
||||
: lifetime(config, config_prefix)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
CatBoostModel::CatBoostModel(const std::string & name, const std::string & model_path, const std::string & lib_path,
|
||||
const ExternalLoadableLifetime & lifetime)
|
||||
: name(name), model_path(model_path), lifetime(lifetime)
|
||||
{
|
||||
try
|
||||
{
|
||||
api_provider = getCatBoostWrapperHolder(lib_path);
|
||||
api = &api_provider->getApi();
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
creation_exception = std::current_exception();
|
||||
}
|
||||
}
|
||||
|
||||
const ExternalLoadableLifetime & CatBoostModel::getLifetime() const
|
||||
{
|
||||
return lifetime;
|
||||
|
@ -7,10 +7,20 @@
|
||||
namespace DB
|
||||
{
|
||||
|
||||
struct CatBoostWrapperApi;
|
||||
class CatBoostWrapperApiProvider
|
||||
{
|
||||
public:
|
||||
virtual ~CatBoostWrapperApiProvider() = default;
|
||||
virtual const CatBoostWrapperApi & getApi() const = 0;
|
||||
};
|
||||
|
||||
|
||||
class CatBoostModel : public IExternalLoadable
|
||||
{
|
||||
public:
|
||||
CatBoostModel(const Poco::Util::AbstractConfiguration & config, const std::string & config_prefix);
|
||||
CatBoostModel(const Poco::Util::AbstractConfiguration & config, const std::string & config_prefix,
|
||||
const std::string & lib_path);
|
||||
|
||||
const ExternalLoadableLifetime & getLifetime() const override;
|
||||
|
||||
@ -30,10 +40,15 @@ public:
|
||||
void apply(const Columns & floatColumns, const Columns & catColumns, ColumnFloat64 & result);
|
||||
|
||||
private:
|
||||
ExternalLoadableLifetime lifetime;
|
||||
std::string name;
|
||||
std::string model_path;
|
||||
ExternalLoadableLifetime lifetime;
|
||||
std::exception_ptr creation_exception;
|
||||
std::shared_ptr<CatBoostWrapperApiProvider> api_provider;
|
||||
const CatBoostWrapperApi * api;
|
||||
|
||||
CatBoostModel(const std::string & name, const std::string & model_path,
|
||||
const std::string & lib_path, const ExternalLoadableLifetime & lifetime);
|
||||
};
|
||||
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user