ClickHouse/dbms/src/Interpreters/CatBoostModel.h

89 lines
2.3 KiB
C++
Raw Normal View History

2017-10-06 14:48:33 +00:00
#pragma once
#include <Interpreters/IExternalLoadable.h>
#include <Columns/IColumn.h>
#include <Columns/ColumnsNumber.h>
namespace DB
{
2017-10-20 15:12:34 +00:00
/// CatBoost wrapper interface functions.
/// Only forward-declared here; this header never needs the complete type.
struct CatBoostWrapperAPI;

/// Abstract accessor for a CatBoostWrapperAPI.
/// Concrete providers are defined elsewhere; this header only fixes the contract.
class CatBoostWrapperAPIProvider
{
public:
    /// Virtual so that providers can be destroyed through a base-class pointer.
    virtual ~CatBoostWrapperAPIProvider() = default;

    /// Returns the wrapper API exposed by this provider.
    /// NOTE(review): presumably the reference is valid only while the provider is alive - confirm against implementations.
    virtual const CatBoostWrapperAPI & getAPI() const = 0;
};
2017-10-20 15:12:34 +00:00
/// CatBoost model interface.
class ICatBoostModel
{
public:
virtual ~ICatBoostModel() = default;
2017-10-20 15:12:34 +00:00
/// Evaluate model. Use first `float_features_count` columns as float features,
/// the others `cat_features_count` as categorical features.
virtual ColumnPtr evaluate(const ColumnRawPtrs & columns) const = 0;
virtual size_t getFloatFeaturesCount() const = 0;
virtual size_t getCatFeaturesCount() const = 0;
virtual size_t getTreeCount() const = 0;
};
2017-10-06 18:05:30 +00:00
class IDataType;
using DataTypePtr = std::shared_ptr<const IDataType>;
2017-10-20 15:12:34 +00:00
/// General ML model evaluator interface.
class IModel : public IExternalLoadable
2017-10-06 14:48:33 +00:00
{
public:
virtual ColumnPtr evaluate(const ColumnRawPtrs & columns) const = 0;
2017-12-04 12:15:21 +00:00
virtual std::string getTypeName() const = 0;
virtual DataTypePtr getReturnType() const = 0;
};
class CatBoostModel : public IModel
{
public:
CatBoostModel(std::string name, std::string model_path,
std::string lib_path, const ExternalLoadableLifetime & lifetime);
ColumnPtr evaluate(const ColumnRawPtrs & columns) const override;
2017-12-04 12:15:21 +00:00
std::string getTypeName() const override { return "catboost"; }
2017-10-20 15:12:34 +00:00
size_t getFloatFeaturesCount() const;
size_t getCatFeaturesCount() const;
size_t getTreeCount() const;
DataTypePtr getReturnType() const override;
2017-10-20 15:12:34 +00:00
/// IExternalLoadable interface.
2017-10-06 14:48:33 +00:00
const ExternalLoadableLifetime & getLifetime() const override;
std::string getName() const override { return name; }
bool supportUpdates() const override { return true; }
bool isModified() const override;
std::shared_ptr<const IExternalLoadable> clone() const override;
2017-10-06 14:48:33 +00:00
private:
std::string name;
2017-10-06 18:05:30 +00:00
std::string model_path;
std::string lib_path;
2017-10-06 18:05:30 +00:00
ExternalLoadableLifetime lifetime;
2017-10-26 12:18:37 +00:00
std::shared_ptr<CatBoostWrapperAPIProvider> api_provider;
const CatBoostWrapperAPI * api;
2017-10-06 14:48:33 +00:00
std::unique_ptr<ICatBoostModel> model;
size_t float_features_count;
size_t cat_features_count;
size_t tree_count;
void init();
2017-10-06 14:48:33 +00:00
};
}