2017-10-06 14:48:33 +00:00
|
|
|
#pragma once
|
|
|
|
#include <Interpreters/IExternalLoadable.h>
|
|
|
|
#include <Columns/IColumn.h>
|
|
|
|
#include <Columns/ColumnsNumber.h>
|
|
|
|
|
|
|
|
|
|
|
|
namespace DB
|
|
|
|
{
|
|
|
|
|
/// Set of CatBoost wrapper interface functions.
/// Only forward-declared here; the full definition is provided elsewhere.
struct CatBoostWrapperAPI;

/// Abstract source of a CatBoostWrapperAPI.
/// Callers obtain the wrapper functions through getAPI() without depending on
/// how the concrete provider acquires them.
class CatBoostWrapperAPIProvider
{
public:
    /// Virtual so that providers can be destroyed through a base pointer
    /// (e.g. a std::shared_ptr<CatBoostWrapperAPIProvider>).
    virtual ~CatBoostWrapperAPIProvider() = default;

    /// Returns the wrapper interface functions.
    /// NOTE(review): the returned reference presumably stays valid only while
    /// the provider is alive — confirm against the implementations.
    virtual const CatBoostWrapperAPI & getAPI() const = 0;
};
|
|
|
|
|
2017-10-20 15:12:34 +00:00
|
|
|
/// CatBoost model interface.
|
2017-10-09 20:13:44 +00:00
|
|
|
class ICatBoostModel
|
|
|
|
{
|
|
|
|
public:
|
|
|
|
virtual ~ICatBoostModel() = default;
|
2017-10-20 15:12:34 +00:00
|
|
|
/// Evaluate model. Use first `float_features_count` columns as float features,
|
|
|
|
/// the others `cat_features_count` as categorical features.
|
2017-10-31 11:18:09 +00:00
|
|
|
virtual ColumnPtr evaluate(const ConstColumnPlainPtrs & columns) const = 0;
|
|
|
|
|
|
|
|
virtual size_t getFloatFeaturesCount() const = 0;
|
|
|
|
virtual size_t getCatFeaturesCount() const = 0;
|
2017-10-09 20:13:44 +00:00
|
|
|
};
|
2017-10-06 18:05:30 +00:00
|
|
|
|
2017-10-20 15:12:34 +00:00
|
|
|
/// General ML model evaluator interface.
|
2017-10-17 10:44:46 +00:00
|
|
|
class IModel : public IExternalLoadable
|
2017-10-06 14:48:33 +00:00
|
|
|
{
|
|
|
|
public:
|
2017-10-26 14:08:05 +00:00
|
|
|
virtual ColumnPtr evaluate(const ConstColumnPlainPtrs & columns) const = 0;
|
2017-10-17 10:44:46 +00:00
|
|
|
};
|
|
|
|
|
|
|
|
class CatBoostModel : public IModel
|
|
|
|
{
|
|
|
|
public:
|
2017-10-31 11:18:09 +00:00
|
|
|
CatBoostModel(std::string name, std::string model_path,
|
|
|
|
std::string lib_path, const ExternalLoadableLifetime & lifetime);
|
2017-10-17 10:44:46 +00:00
|
|
|
|
2017-10-26 14:08:05 +00:00
|
|
|
ColumnPtr evaluate(const ConstColumnPlainPtrs & columns) const override;
|
2017-10-20 15:12:34 +00:00
|
|
|
|
|
|
|
size_t getFloatFeaturesCount() const;
|
|
|
|
size_t getCatFeaturesCount() const;
|
|
|
|
|
|
|
|
/// IExternalLoadable interface.
|
2017-10-06 14:48:33 +00:00
|
|
|
|
|
|
|
const ExternalLoadableLifetime & getLifetime() const override;
|
|
|
|
|
|
|
|
std::string getName() const override { return name; }
|
|
|
|
|
|
|
|
bool supportUpdates() const override { return true; }
|
|
|
|
|
|
|
|
bool isModified() const override;
|
|
|
|
|
2017-10-26 13:36:01 +00:00
|
|
|
std::unique_ptr<IExternalLoadable> clone() const override;
|
2017-10-06 14:48:33 +00:00
|
|
|
|
|
|
|
std::exception_ptr getCreationException() const override { return creation_exception; }
|
|
|
|
|
|
|
|
private:
|
|
|
|
std::string name;
|
2017-10-06 18:05:30 +00:00
|
|
|
std::string model_path;
|
2017-10-09 20:13:44 +00:00
|
|
|
std::string lib_path;
|
2017-10-06 18:05:30 +00:00
|
|
|
ExternalLoadableLifetime lifetime;
|
2017-10-06 14:48:33 +00:00
|
|
|
std::exception_ptr creation_exception;
|
2017-10-26 12:18:37 +00:00
|
|
|
std::shared_ptr<CatBoostWrapperAPIProvider> api_provider;
|
|
|
|
const CatBoostWrapperAPI * api;
|
2017-10-06 14:48:33 +00:00
|
|
|
|
2017-10-09 20:13:44 +00:00
|
|
|
std::unique_ptr<ICatBoostModel> model;
|
|
|
|
|
|
|
|
size_t float_features_count;
|
|
|
|
size_t cat_features_count;
|
|
|
|
|
|
|
|
void init(const std::string & lib_path);
|
2017-10-06 14:48:33 +00:00
|
|
|
};
|
|
|
|
|
|
|
|
}
|