ClickHouse/programs/library-bridge/CatBoostLibraryHandler.h
Robert Schulze fac1be9700
chore: restore SYSTEM RELOAD MODEL(S) and monitoring view SYSTEM.MODELS
- This commit restores the statements "SYSTEM RELOAD MODEL(S)" which provide
  a mechanism to update a model explicitly. It also saves potentially
  unnecessary reloads of a model from disk after its initial load.

  To keep the complexity low, the semantics of "SYSTEM RELOAD MODEL(S)"
  were changed from eager to lazy. This means that both statements
  previously immediately reloaded the specified/all models, whereas now
  the statements only trigger an unload and the first call to
  catboostEvaluate() does the actual load.

- Monitoring view SYSTEM.MODELS is also restored but with some obsolete
  fields removed. The view was not documented in the past and for now it
  remains undocumented. The commit is thus not considered a breach of
  ClickHouse's public interface.
2022-09-12 19:33:02 +00:00

79 lines
2.6 KiB
C++

#pragma once
#include "CatBoostLibraryAPI.h"
#include <Columns/ColumnFixedString.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnVector.h>
#include <Columns/ColumnsNumber.h>
#include <Columns/IColumn.h>
#include <Common/SharedLibrary.h>
#include <base/defines.h>
#include <chrono>
#include <mutex>
namespace DB
{
/// Abstracts access to the CatBoost shared library.
///
/// An instance is constructed from a library path and a model path. The model
/// handle and the cached feature/tree counts are annotated as guarded by `mutex`.
class CatBoostLibraryHandler
{
public:
    /// Holds pointers to CatBoost library functions, one member per API function type
    /// declared in CatBoostLibraryAPI.
    struct APIHolder
    {
        /// NOTE(review): presumably looks the functions up in `lib`; confirm in the .cpp file.
        explicit APIHolder(SharedLibrary & lib);

        /// The member names do not follow the usual identifier naming rule, so the check is
        /// suppressed for this region (presumably they mirror the CatBoost symbol names).
        // NOLINTBEGIN(readability-identifier-naming)
        CatBoostLibraryAPI::ModelCalcerCreateFunc ModelCalcerCreate;
        CatBoostLibraryAPI::ModelCalcerDeleteFunc ModelCalcerDelete;

        CatBoostLibraryAPI::GetErrorStringFunc GetErrorString;

        CatBoostLibraryAPI::LoadFullModelFromFileFunc LoadFullModelFromFile;

        CatBoostLibraryAPI::CalcModelPredictionFlatFunc CalcModelPredictionFlat;
        CatBoostLibraryAPI::CalcModelPredictionFunc CalcModelPrediction;
        CatBoostLibraryAPI::CalcModelPredictionWithHashedCatFeaturesFunc CalcModelPredictionWithHashedCatFeatures;

        CatBoostLibraryAPI::GetStringCatFeatureHashFunc GetStringCatFeatureHash;
        CatBoostLibraryAPI::GetIntegerCatFeatureHashFunc GetIntegerCatFeatureHash;

        CatBoostLibraryAPI::GetFloatFeaturesCountFunc GetFloatFeaturesCount;
        CatBoostLibraryAPI::GetCatFeaturesCountFunc GetCatFeaturesCount;
        CatBoostLibraryAPI::GetTreeCountFunc GetTreeCount;
        CatBoostLibraryAPI::GetDimensionsCountFunc GetDimensionsCount;
        // NOLINTEND(readability-identifier-naming)
    };

    /// @param library_path  path of the CatBoost shared library
    /// @param model_path    path of the model file
    CatBoostLibraryHandler(const String & library_path, const String & model_path);

    ~CatBoostLibraryHandler();

    /// Evaluates the given columns and returns the result as a column.
    /// NOTE(review): expected column types/order are defined by the implementation; see the .cpp file.
    ColumnPtr evaluate(const ColumnRawPtrs & columns) const;

    /// Accessors for monitoring; presumably they expose `tree_count`, `loading_start_time`
    /// and `loading_duration` respectively (confirm in the .cpp file).
    size_t getTreeCount() const;
    std::chrono::system_clock::time_point getLoadingStartTime() const;
    std::chrono::milliseconds getLoadingDuration() const;

private:
    std::chrono::system_clock::time_point loading_start_time;
    std::chrono::milliseconds loading_duration;

    /// Keep `library` declared before `api`: members are initialized in declaration order.
    /// NOTE(review): presumably `api` is built from `*library`; confirm in the constructor.
    const SharedLibraryPtr library;
    const APIHolder api;

    /// Declared mutable so that const member functions can lock it.
    mutable std::mutex mutex;

    /// Both the pointer itself and the object it points to are annotated as guarded by `mutex`.
    CatBoostLibraryAPI::ModelCalcerHandle * model_calcer_handle TSA_GUARDED_BY(mutex) TSA_PT_GUARDED_BY(mutex);

    size_t float_features_count TSA_GUARDED_BY(mutex);
    size_t cat_features_count TSA_GUARDED_BY(mutex);
    size_t tree_count TSA_GUARDED_BY(mutex);

    /// Internal evaluation routine; the caller must already hold `mutex`.
    /// NOTE(review): `cat_features_are_strings` presumably distinguishes string from
    /// non-string categorical feature columns; confirm against the implementation.
    ColumnFloat64::MutablePtr evalImpl(const ColumnRawPtrs & columns, bool cat_features_are_strings) const TSA_REQUIRES(mutex);
};
/// Shared-ownership pointer to a CatBoostLibraryHandler.
using CatBoostLibraryHandlerPtr = std::shared_ptr<CatBoostLibraryHandler>;
}