chore: restore SYSTEM RELOAD MODEL(S) and moniting view SYSTEM.MODELS

- This commit restores statements "SYSTEM RELOAD MODEL(S)" which provide
  a mechanism to update a model explicitly. It also saves potentially
  unnecessary reloads of a model from disk after it's initial load.

  To keep the complexity low, the semantics of "SYSTEM RELOAD MODEL(S)
  was changed from eager to lazy. This means that both statements
  previously immedately reloaded the specified/all models, whereas now
  the statements only trigger an unload and the first call to
  catboostEvaluate() does the actual load.

- Monitoring view SYSTEM.MODELS is also restored but with some obsolete
  fields removed. The view was not documented in the past and for now it
  remains undocumented. The commit is thus not considered a breach of
  ClickHouse's public interface.
This commit is contained in:
Robert Schulze 2022-09-11 15:50:36 +00:00
parent c16707ff00
commit fac1be9700
No known key found for this signature in database
GPG Key ID: 26703B55FB13728A
16 changed files with 459 additions and 64 deletions

View File

@ -11,6 +11,8 @@ The list of available `SYSTEM` statements:
- [RELOAD EMBEDDED DICTIONARIES](#query_language-system-reload-emdedded-dictionaries) - [RELOAD EMBEDDED DICTIONARIES](#query_language-system-reload-emdedded-dictionaries)
- [RELOAD DICTIONARIES](#query_language-system-reload-dictionaries) - [RELOAD DICTIONARIES](#query_language-system-reload-dictionaries)
- [RELOAD DICTIONARY](#query_language-system-reload-dictionary) - [RELOAD DICTIONARY](#query_language-system-reload-dictionary)
- [RELOAD MODELS](#query_language-system-reload-models)
- [RELOAD MODEL](#query_language-system-reload-model)
- [RELOAD FUNCTIONS](#query_language-system-reload-functions) - [RELOAD FUNCTIONS](#query_language-system-reload-functions)
- [RELOAD FUNCTION](#query_language-system-reload-functions) - [RELOAD FUNCTION](#query_language-system-reload-functions)
- [DROP DNS CACHE](#query_language-system-drop-dns-cache) - [DROP DNS CACHE](#query_language-system-drop-dns-cache)
@ -65,6 +67,26 @@ The status of the dictionary can be checked by querying the `system.dictionaries
SELECT name, status FROM system.dictionaries; SELECT name, status FROM system.dictionaries;
``` ```
## RELOAD MODELS
Reloads all [CatBoost](../../guides/developer/apply-catboost-model.md) models.
**Syntax**
```sql
SYSTEM RELOAD MODELS [ON CLUSTER cluster_name]
```
## RELOAD MODEL
Completely reloads a CatBoost model `model_path`.
**Syntax**
```sql
SYSTEM RELOAD MODEL [ON CLUSTER cluster_name] <model_path>
```
## RELOAD FUNCTIONS ## RELOAD FUNCTIONS
Reloads all registered [executable user defined functions](../functions/index.md#executable-user-defined-functions) or one of them from a configuration file. Reloads all registered [executable user defined functions](../functions/index.md#executable-user-defined-functions) or one of them from a configuration file.

View File

@ -9,6 +9,8 @@ sidebar_label: SYSTEM
- [RELOAD EMBEDDED DICTIONARIES](#query_language-system-reload-emdedded-dictionaries) - [RELOAD EMBEDDED DICTIONARIES](#query_language-system-reload-emdedded-dictionaries)
- [RELOAD DICTIONARIES](#query_language-system-reload-dictionaries) - [RELOAD DICTIONARIES](#query_language-system-reload-dictionaries)
- [RELOAD DICTIONARY](#query_language-system-reload-dictionary) - [RELOAD DICTIONARY](#query_language-system-reload-dictionary)
- [RELOAD MODELS](#query_language-system-reload-models)
- [RELOAD MODEL](#query_language-system-reload-model)
- [RELOAD FUNCTIONS](#query_language-system-reload-functions) - [RELOAD FUNCTIONS](#query_language-system-reload-functions)
- [RELOAD FUNCTION](#query_language-system-reload-functions) - [RELOAD FUNCTION](#query_language-system-reload-functions)
- [DROP DNS CACHE](#query_language-system-drop-dns-cache) - [DROP DNS CACHE](#query_language-system-drop-dns-cache)
@ -62,6 +64,26 @@ sidebar_label: SYSTEM
SELECT name, status FROM system.dictionaries; SELECT name, status FROM system.dictionaries;
``` ```
## RELOAD MODELS {#query_language-system-reload-models}
Перегружает все модели [CatBoost](../../guides/apply-catboost-model.md#applying-catboost-model-in-clickhouse).
**Синтаксис**
```sql
SYSTEM RELOAD MODELS
```
## RELOAD MODEL {#query_language-system-reload-model}
Полностью перегружает модель [CatBoost](../../guides/apply-catboost-model.md#applying-catboost-model-in-clickhouse) `model_path`.
**Синтаксис**
```sql
SYSTEM RELOAD MODEL <model_path>
```
## RELOAD FUNCTIONS {#query_language-system-reload-functions} ## RELOAD FUNCTIONS {#query_language-system-reload-functions}
Перезагружает все зарегистрированные [исполняемые пользовательские функции](../functions/index.md#executable-user-defined-functions) или одну из них из файла конфигурации. Перезагружает все зарегистрированные [исполняемые пользовательские функции](../functions/index.md#executable-user-defined-functions) или одну из них из файла конфигурации.

View File

@ -34,7 +34,8 @@ CatBoostLibraryHandler::APIHolder::APIHolder(SharedLibrary & lib)
CatBoostLibraryHandler::CatBoostLibraryHandler( CatBoostLibraryHandler::CatBoostLibraryHandler(
const std::string & library_path, const std::string & library_path,
const std::string & model_path) const std::string & model_path)
: library(std::make_shared<SharedLibrary>(library_path)) : loading_start_time(std::chrono::system_clock::now())
, library(std::make_shared<SharedLibrary>(library_path))
, api(*library) , api(*library)
{ {
model_calcer_handle = api.ModelCalcerCreate(); model_calcer_handle = api.ModelCalcerCreate();
@ -51,6 +52,8 @@ CatBoostLibraryHandler::CatBoostLibraryHandler(
tree_count = 1; tree_count = 1;
if (api.GetDimensionsCount) if (api.GetDimensionsCount)
tree_count = api.GetDimensionsCount(model_calcer_handle); tree_count = api.GetDimensionsCount(model_calcer_handle);
loading_duration = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now() - loading_start_time);
} }
CatBoostLibraryHandler::~CatBoostLibraryHandler() CatBoostLibraryHandler::~CatBoostLibraryHandler()
@ -58,6 +61,16 @@ CatBoostLibraryHandler::~CatBoostLibraryHandler()
api.ModelCalcerDelete(model_calcer_handle); api.ModelCalcerDelete(model_calcer_handle);
} }
std::chrono::system_clock::time_point CatBoostLibraryHandler::getLoadingStartTime() const
{
return loading_start_time;
}
std::chrono::milliseconds CatBoostLibraryHandler::getLoadingDuration() const
{
return loading_duration;
}
namespace namespace
{ {

View File

@ -10,6 +10,7 @@
#include <Common/SharedLibrary.h> #include <Common/SharedLibrary.h>
#include <base/defines.h> #include <base/defines.h>
#include <chrono>
#include <mutex> #include <mutex>
namespace DB namespace DB
@ -42,16 +43,22 @@ public:
}; };
CatBoostLibraryHandler( CatBoostLibraryHandler(
const std::string & library_path, const String & library_path,
const std::string & model_path); const String & model_path);
~CatBoostLibraryHandler(); ~CatBoostLibraryHandler();
std::chrono::system_clock::time_point getLoadingStartTime() const;
std::chrono::milliseconds getLoadingDuration() const;
size_t getTreeCount() const; size_t getTreeCount() const;
ColumnPtr evaluate(const ColumnRawPtrs & columns) const; ColumnPtr evaluate(const ColumnRawPtrs & columns) const;
private: private:
std::chrono::system_clock::time_point loading_start_time;
std::chrono::milliseconds loading_duration;
const SharedLibraryPtr library; const SharedLibraryPtr library;
const APIHolder api; const APIHolder api;

View File

@ -2,6 +2,7 @@
#include <Common/logger_useful.h> #include <Common/logger_useful.h>
namespace DB namespace DB
{ {
@ -11,39 +12,64 @@ CatBoostLibraryHandlerFactory & CatBoostLibraryHandlerFactory::instance()
return instance; return instance;
} }
CatBoostLibraryHandlerPtr CatBoostLibraryHandlerFactory::get(const String & model_path) CatBoostLibraryHandlerPtr CatBoostLibraryHandlerFactory::getOrCreateModel(const String & model_path, const String & library_path, bool create_if_not_found)
{ {
std::lock_guard lock(mutex); std::lock_guard lock(mutex);
if (auto handler = library_handlers.find(model_path); handler != library_handlers.end()) auto handler = library_handlers.find(model_path);
bool found = (handler != library_handlers.end());
if (found)
return handler->second; return handler->second;
return nullptr; else
}
void CatBoostLibraryHandlerFactory::create(const String & library_path, const String & model_path)
{
std::lock_guard lock(mutex);
if (library_handlers.contains(model_path))
{ {
LOG_DEBUG(&Poco::Logger::get("CatBoostLibraryHandlerFactory"), "Cannot load catboost library handler for model path {} because it exists already", model_path); if (create_if_not_found)
return; {
auto new_handler = std::make_shared<CatBoostLibraryHandler>(library_path, model_path);
library_handlers.emplace(std::make_pair(model_path, new_handler));
LOG_DEBUG(&Poco::Logger::get("CatBoostLibraryHandlerFactory"), "Loaded catboost library handler for model path '{}'", model_path);
return new_handler;
}
return nullptr;
} }
library_handlers.emplace(std::make_pair(model_path, std::make_shared<CatBoostLibraryHandler>(library_path, model_path)));
LOG_DEBUG(&Poco::Logger::get("CatBoostLibraryHandlerFactory"), "Loaded catboost library handler for model path {}.", model_path);
} }
void CatBoostLibraryHandlerFactory::remove(const String & model_path) void CatBoostLibraryHandlerFactory::removeModel(const String & model_path)
{ {
std::lock_guard lock(mutex); std::lock_guard lock(mutex);
bool deleted = library_handlers.erase(model_path); bool deleted = library_handlers.erase(model_path);
if (!deleted) if (!deleted)
{ {
LOG_DEBUG(&Poco::Logger::get("CatBoostLibraryHandlerFactory"), "Cannot unload catboost library handler for model path: {}", model_path); LOG_DEBUG(&Poco::Logger::get("CatBoostLibraryHandlerFactory"), "Cannot unload catboost library handler for model path '{}'", model_path);
return; return;
} }
LOG_DEBUG(&Poco::Logger::get("CatBoostLibraryHandlerFactory"), "Unloaded catboost library handler for model path: {}", model_path); LOG_DEBUG(&Poco::Logger::get("CatBoostLibraryHandlerFactory"), "Unloaded catboost library handler for model path '{}'", model_path);
}
void CatBoostLibraryHandlerFactory::removeAllModels()
{
std::lock_guard lock(mutex);
library_handlers.clear();
LOG_DEBUG(&Poco::Logger::get("CatBoostLibraryHandlerFactory"), "Unloaded all catboost library handlers");
}
ExternalModelInfos CatBoostLibraryHandlerFactory::getModelInfos()
{
std::lock_guard lock(mutex);
ExternalModelInfos result;
for (const auto & handler : library_handlers)
result.push_back({
.model_path = handler.first,
.model_type = "catboost",
.loading_start_time = handler.second->getLoadingStartTime(),
.loading_duration = handler.second->getLoadingDuration()
});
return result;
} }
} }

View File

@ -3,7 +3,9 @@
#include "CatBoostLibraryHandler.h" #include "CatBoostLibraryHandler.h"
#include <base/defines.h> #include <base/defines.h>
#include <Common/ExternalModelInfo.h>
#include <chrono>
#include <mutex> #include <mutex>
#include <unordered_map> #include <unordered_map>
@ -16,14 +18,15 @@ class CatBoostLibraryHandlerFactory final : private boost::noncopyable
public: public:
static CatBoostLibraryHandlerFactory & instance(); static CatBoostLibraryHandlerFactory & instance();
CatBoostLibraryHandlerPtr get(const String & model_path); CatBoostLibraryHandlerPtr getOrCreateModel(const String & model_path, const String & library_path, bool create_if_not_found);
void create(const String & library_path, const String & model_path); void removeModel(const String & model_path);
void removeAllModels();
void remove(const String & model_path); ExternalModelInfos getModelInfos();
private: private:
/// map: model path -> shared library handler /// map: model path --> catboost library handler
std::unordered_map<String, CatBoostLibraryHandlerPtr> library_handlers TSA_GUARDED_BY(mutex); std::unordered_map<String, CatBoostLibraryHandlerPtr> library_handlers TSA_GUARDED_BY(mutex);
std::mutex mutex; std::mutex mutex;
}; };

View File

@ -26,6 +26,7 @@
#include <Formats/NativeReader.h> #include <Formats/NativeReader.h>
#include <Formats/NativeWriter.h> #include <Formats/NativeWriter.h>
#include <DataTypes/DataTypesNumber.h> #include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeString.h>
namespace DB namespace DB
@ -422,8 +423,6 @@ CatBoostLibraryBridgeRequestHandler::CatBoostLibraryBridgeRequestHandler(
void CatBoostLibraryBridgeRequestHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) void CatBoostLibraryBridgeRequestHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response)
{ {
std::lock_guard lock(mutex);
LOG_TRACE(log, "Request URI: {}", request.getURI()); LOG_TRACE(log, "Request URI: {}", request.getURI());
HTMLForm params(getContext()->getSettingsRef(), request); HTMLForm params(getContext()->getSettingsRef(), request);
@ -460,7 +459,51 @@ void CatBoostLibraryBridgeRequestHandler::handleRequest(HTTPServerRequest & requ
try try
{ {
if (method == "catboost_GetTreeCount") if (method == "catboost_list")
{
ExternalModelInfos model_infos = CatBoostLibraryHandlerFactory::instance().getModelInfos();
writeIntBinary(static_cast<UInt64>(model_infos.size()), out);
for (const auto & info : model_infos)
{
writeStringBinary(info.model_path, out);
writeStringBinary(info.model_type, out);
UInt64 t = std::chrono::system_clock::to_time_t(info.loading_start_time);
writeIntBinary(t, out);
t = info.loading_duration.count();
writeIntBinary(t, out);
}
}
else if (method == "catboost_removeModel")
{
auto & read_buf = request.getStream();
params.read(read_buf);
if (!params.has("model_path"))
{
processError(response, "No 'model_path' in request URL");
return;
}
const String & model_path = params.get("model_path");
CatBoostLibraryHandlerFactory::instance().removeModel(model_path);
String res = "1";
writeStringBinary(res, out);
}
else if (method == "catboost_removeAllModels")
{
CatBoostLibraryHandlerFactory::instance().removeAllModels();
String res = "1";
writeStringBinary(res, out);
}
else if (method == "catboost_GetTreeCount")
{ {
auto & read_buf = request.getStream(); auto & read_buf = request.getStream();
params.read(read_buf); params.read(read_buf);
@ -481,18 +524,7 @@ void CatBoostLibraryBridgeRequestHandler::handleRequest(HTTPServerRequest & requ
const String & model_path = params.get("model_path"); const String & model_path = params.get("model_path");
CatBoostLibraryHandlerFactory::instance().remove(model_path); auto catboost_handler = CatBoostLibraryHandlerFactory::instance().getOrCreateModel(model_path, library_path, /*create_if_not_found*/ true);
CatBoostLibraryHandlerFactory::instance().create(library_path, model_path);
auto catboost_handler = CatBoostLibraryHandlerFactory::instance().get(model_path);
if (!catboost_handler)
{
processError(response, "CatBoost library is not loaded for model " + model_path);
return;
}
size_t tree_count = catboost_handler->getTreeCount(); size_t tree_count = catboost_handler->getTreeCount();
writeIntBinary(tree_count, out); writeIntBinary(tree_count, out);
} }
@ -526,11 +558,11 @@ void CatBoostLibraryBridgeRequestHandler::handleRequest(HTTPServerRequest & requ
for (const auto & p : col_ptrs) for (const auto & p : col_ptrs)
col_raw_ptrs.push_back(&*p); col_raw_ptrs.push_back(&*p);
auto catboost_handler = CatBoostLibraryHandlerFactory::instance().get(model_path); auto catboost_handler = CatBoostLibraryHandlerFactory::instance().getOrCreateModel(model_path, "DummyLibraryPath", /*create_if_not_found*/ false);
if (!catboost_handler) if (!catboost_handler)
{ {
processError(response, "CatBoost library is not loaded for model" + model_path); processError(response, "CatBoost library is not loaded for model '" + model_path + "'. Please try again.");
return; return;
} }

View File

@ -3,7 +3,6 @@
#include <Common/logger_useful.h> #include <Common/logger_useful.h>
#include <Interpreters/Context.h> #include <Interpreters/Context.h>
#include <Server/HTTP/HTTPRequestHandler.h> #include <Server/HTTP/HTTPRequestHandler.h>
#include <mutex>
namespace DB namespace DB
@ -46,16 +45,21 @@ private:
/// Handler for requests to catboost library. The call protocol is as follows: /// Handler for requests to catboost library. The call protocol is as follows:
/// (1) Send a "catboost_GetTreeCount" request from the server to the bridge, containing a library path (e.g /home/user/libcatboost.so) and /// (1) Send a "catboost_GetTreeCount" request from the server to the bridge. It contains a library path (e.g /home/user/libcatboost.so) and
/// a model path (e.g. /home/user/model.bin). Rirst, this unloads the catboost library handler associated to the model path (if it was /// a model path (e.g. /home/user/model.bin). This loads the catboost library handler associated with the model path, then executes
/// loaded), then loads the catboost library handler associated to the model path, then executes GetTreeCount() on the library handler /// GetTreeCount() on the library handler and sends the result back to the server.
/// and finally sends the result back to the server. /// (2) Send "catboost_Evaluate" from the server to the bridge. It contains a model path and the features to run the interference on. Step
/// Step (1) is called once by the server from FunctionCatBoostEvaluate::getReturnTypeImpl(). The library path handler is unloaded in /// (2) is called multiple times (once per chunk) by the server.
/// the beginning because it contains state which may no longer be valid if the user runs catboost("/path/to/model.bin", ...) more than ///
/// once and if "model.bin" was updated in between. /// We would ideally like to have steps (1) and (2) in one atomic handler but can't because the evaluation on the server side is divided
/// (2) Send "catboost_Evaluate" from the server to the bridge, containing the model path and the features to run the interference on. /// into two dependent phases: FunctionCatBoostEvaluate::getReturnTypeImpl() and ::executeImpl(). So the model may in principle be unloaded
/// Step (2) is called multiple times (once per chunk) by the server from function FunctionCatBoostEvaluate::executeImpl(). The library /// from the library-bridge between steps (1) and (2). Step (2) checks if that is the case and fails gracefully. This is okay because that
/// handler for the given model path is expected to be already loaded by Step (1). /// situation considered exceptional and rare.
///
/// An update of a model is performed by unloading it. The first call to "catboost_GetTreeCount" brings it into memory again.
///
/// Further handlers are provided for unloading a specific model, for unloading all models or for retrieving information about the loaded
/// models for display in a system view.
class CatBoostLibraryBridgeRequestHandler : public HTTPRequestHandler, WithContext class CatBoostLibraryBridgeRequestHandler : public HTTPRequestHandler, WithContext
{ {
public: public:
@ -64,7 +68,6 @@ public:
void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) override; void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) override;
private: private:
std::mutex mutex;
const size_t keep_alive_timeout; const size_t keep_alive_timeout;
Poco::Logger * log; Poco::Logger * log;
}; };

View File

@ -10,6 +10,8 @@
#include <IO/WriteBufferFromString.h> #include <IO/WriteBufferFromString.h>
#include <Poco/Net/HTTPRequest.h> #include <Poco/Net/HTTPRequest.h>
#include <random>
namespace DB namespace DB
{ {
@ -28,6 +30,16 @@ CatBoostLibraryBridgeHelper::CatBoostLibraryBridgeHelper(
{ {
} }
CatBoostLibraryBridgeHelper::CatBoostLibraryBridgeHelper(ContextPtr context_, std::string_view model_path_)
: CatBoostLibraryBridgeHelper(context_, "", model_path_)
{
}
CatBoostLibraryBridgeHelper::CatBoostLibraryBridgeHelper(ContextPtr context_)
: CatBoostLibraryBridgeHelper(context_, "", "")
{
}
Poco::URI CatBoostLibraryBridgeHelper::getPingURI() const Poco::URI CatBoostLibraryBridgeHelper::getPingURI() const
{ {
auto uri = createBaseURI(); auto uri = createBaseURI();
@ -71,6 +83,74 @@ bool CatBoostLibraryBridgeHelper::bridgeHandShake()
return true; return true;
} }
ExternalModelInfos CatBoostLibraryBridgeHelper::listModels()
{
startBridgeSync();
ReadWriteBufferFromHTTP buf(
createRequestURI(CATBOOST_LIST_METHOD),
Poco::Net::HTTPRequest::HTTP_POST,
[](std::ostream &) {},
http_timeouts, credentials);
ExternalModelInfos result;
UInt64 num_rows;
readIntBinary(num_rows, buf);
for (UInt64 i = 0; i < num_rows; ++i)
{
ExternalModelInfo info;
readStringBinary(info.model_path, buf);
readStringBinary(info.model_type, buf);
UInt64 t;
readIntBinary(t, buf);
info.loading_start_time = std::chrono::system_clock::from_time_t(t);
readIntBinary(t, buf);
info.loading_duration = std::chrono::milliseconds(t);
result.push_back(info);
}
return result;
}
void CatBoostLibraryBridgeHelper::removeModel()
{
startBridgeSync();
ReadWriteBufferFromHTTP buf(
createRequestURI(CATBOOST_REMOVEMODEL_METHOD),
Poco::Net::HTTPRequest::HTTP_POST,
[this](std::ostream & os)
{
os << "model_path=" << escapeForFileName(model_path);
},
http_timeouts, credentials);
String result;
readStringBinary(result, buf);
assert(result == "1");
}
void CatBoostLibraryBridgeHelper::removeAllModels()
{
startBridgeSync();
ReadWriteBufferFromHTTP buf(
createRequestURI(CATBOOST_REMOVEALLMODELS_METHOD),
Poco::Net::HTTPRequest::HTTP_POST,
[](std::ostream &){},
http_timeouts, credentials);
String result;
readStringBinary(result, buf);
assert(result == "1");
}
size_t CatBoostLibraryBridgeHelper::getTreeCount() size_t CatBoostLibraryBridgeHelper::getTreeCount()
{ {
startBridgeSync(); startBridgeSync();
@ -85,9 +165,9 @@ size_t CatBoostLibraryBridgeHelper::getTreeCount()
}, },
http_timeouts, credentials); http_timeouts, credentials);
size_t res; size_t result;
readIntBinary(res, buf); readIntBinary(result, buf);
return res; return result;
} }
ColumnPtr CatBoostLibraryBridgeHelper::evaluate(const ColumnsWithTypeAndName & columns) ColumnPtr CatBoostLibraryBridgeHelper::evaluate(const ColumnsWithTypeAndName & columns)

View File

@ -1,6 +1,7 @@
#pragma once #pragma once
#include <BridgeHelper/LibraryBridgeHelper.h> #include <BridgeHelper/LibraryBridgeHelper.h>
#include <Common/ExternalModelInfo.h>
#include <DataTypes/IDataType.h> #include <DataTypes/IDataType.h>
#include <IO/ReadWriteBufferFromHTTP.h> #include <IO/ReadWriteBufferFromHTTP.h>
#include <Interpreters/Context.h> #include <Interpreters/Context.h>
@ -17,9 +18,15 @@ public:
static constexpr inline auto MAIN_HANDLER = "/catboost_request"; static constexpr inline auto MAIN_HANDLER = "/catboost_request";
CatBoostLibraryBridgeHelper(ContextPtr context_, std::string_view library_path_, std::string_view model_path_); CatBoostLibraryBridgeHelper(ContextPtr context_, std::string_view library_path_, std::string_view model_path_);
CatBoostLibraryBridgeHelper(ContextPtr context_, std::string_view model_path_);
explicit CatBoostLibraryBridgeHelper(ContextPtr context_);
ExternalModelInfos listModels();
void removeModel();
void removeAllModels();
size_t getTreeCount(); size_t getTreeCount();
ColumnPtr evaluate(const ColumnsWithTypeAndName & columns); ColumnPtr evaluate(const ColumnsWithTypeAndName & columns);
protected: protected:
@ -30,6 +37,9 @@ protected:
bool bridgeHandShake() override; bool bridgeHandShake() override;
private: private:
static constexpr inline auto CATBOOST_LIST_METHOD = "catboost_list";
static constexpr inline auto CATBOOST_REMOVEMODEL_METHOD = "catboost_removeModel";
static constexpr inline auto CATBOOST_REMOVEALLMODELS_METHOD = "catboost_removeAllModels";
static constexpr inline auto CATBOOST_GETTREECOUNT_METHOD = "catboost_GetTreeCount"; static constexpr inline auto CATBOOST_GETTREECOUNT_METHOD = "catboost_GetTreeCount";
static constexpr inline auto CATBOOST_LIB_EVALUATE_METHOD = "catboost_libEvaluate"; static constexpr inline auto CATBOOST_LIB_EVALUATE_METHOD = "catboost_libEvaluate";

View File

@ -0,0 +1,20 @@
#pragma once
#include <vector>
#include <base/types.h>
namespace DB
{
/// Details about external machine learning model, used by clickhouse-server and clickhouse-library-bridge
struct ExternalModelInfo
{
String model_path;
String model_type;
std::chrono::system_clock::time_point loading_start_time; /// serialized as std::time_t
std::chrono::milliseconds loading_duration; /// serialized as UInt64
};
using ExternalModelInfos = std::vector<ExternalModelInfo>;
}

View File

@ -35,6 +35,7 @@
#include <Interpreters/ProcessorsProfileLog.h> #include <Interpreters/ProcessorsProfileLog.h>
#include <Interpreters/JIT/CompiledExpressionCache.h> #include <Interpreters/JIT/CompiledExpressionCache.h>
#include <Interpreters/TransactionLog.h> #include <Interpreters/TransactionLog.h>
#include <BridgeHelper/CatBoostLibraryBridgeHelper.h>
#include <Access/ContextAccess.h> #include <Access/ContextAccess.h>
#include <Access/Common/AllowedClientHosts.h> #include <Access/Common/AllowedClientHosts.h>
#include <Databases/IDatabase.h> #include <Databases/IDatabase.h>
@ -373,10 +374,17 @@ BlockIO InterpreterSystemQuery::execute()
break; break;
} }
case Type::RELOAD_MODEL: case Type::RELOAD_MODEL:
{
getContext()->checkAccess(AccessType::SYSTEM_RELOAD_MODEL);
auto bridge_helper = std::make_unique<CatBoostLibraryBridgeHelper>(getContext(), query.target_model);
bridge_helper->removeModel();
break;
}
case Type::RELOAD_MODELS: case Type::RELOAD_MODELS:
{ {
getContext()->checkAccess(AccessType::SYSTEM_RELOAD_MODEL); getContext()->checkAccess(AccessType::SYSTEM_RELOAD_MODEL);
/// SYSTEM RELOAD MODEL(S) is no longer supported. For compat reasons, it is not completely removed but retained as no-op. auto bridge_helper = std::make_unique<CatBoostLibraryBridgeHelper>(getContext());
bridge_helper->removeAllModels();
break; break;
} }
case Type::RELOAD_FUNCTION: case Type::RELOAD_FUNCTION:

View File

@ -0,0 +1,38 @@
#include <Storages/System/StorageSystemModels.h>
#include <Common/ExternalModelInfo.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeDateTime.h>
#include <DataTypes/DataTypeEnum.h>
#include <Interpreters/Context.h>
#include <BridgeHelper/CatBoostLibraryBridgeHelper.h>
namespace DB
{
NamesAndTypesList StorageSystemModels::getNamesAndTypes()
{
return {
{ "model_path", std::make_shared<DataTypeString>() },
{ "type", std::make_shared<DataTypeString>() },
{ "loading_start_time", std::make_shared<DataTypeDateTime>() },
{ "loading_duration", std::make_shared<DataTypeFloat32>() },
};
}
void StorageSystemModels::fillData(MutableColumns & res_columns, ContextPtr context, const SelectQueryInfo &) const
{
auto bridge_helper = std::make_unique<CatBoostLibraryBridgeHelper>(context);
ExternalModelInfos infos = bridge_helper->listModels();
for (const auto & info : infos)
{
res_columns[0]->insert(info.model_path);
res_columns[1]->insert(info.model_type);
res_columns[2]->insert(static_cast<UInt64>(std::chrono::system_clock::to_time_t(info.loading_start_time)));
res_columns[3]->insert(std::chrono::duration_cast<std::chrono::duration<float>>(info.loading_duration).count());
}
}
}

View File

@ -0,0 +1,25 @@
#pragma once
#include <Storages/System/IStorageSystemOneBlock.h>
namespace DB
{
class Context;
class StorageSystemModels final : public IStorageSystemOneBlock<StorageSystemModels>
{
public:
std::string getName() const override { return "SystemModels"; }
static NamesAndTypesList getNamesAndTypes();
protected:
using IStorageSystemOneBlock::IStorageSystemOneBlock;
void fillData(MutableColumns & res_columns, ContextPtr context, const SelectQueryInfo & query_info) const override;
};
}

View File

@ -25,6 +25,7 @@
#include <Storages/System/StorageSystemMerges.h> #include <Storages/System/StorageSystemMerges.h>
#include <Storages/System/StorageSystemReplicatedFetches.h> #include <Storages/System/StorageSystemReplicatedFetches.h>
#include <Storages/System/StorageSystemMetrics.h> #include <Storages/System/StorageSystemMetrics.h>
#include <Storages/System/StorageSystemModels.h>
#include <Storages/System/StorageSystemMutations.h> #include <Storages/System/StorageSystemMutations.h>
#include <Storages/System/StorageSystemNumbers.h> #include <Storages/System/StorageSystemNumbers.h>
#include <Storages/System/StorageSystemOne.h> #include <Storages/System/StorageSystemOne.h>
@ -163,6 +164,7 @@ void attachSystemTablesServer(ContextPtr context, IDatabase & system_database, b
attach<StorageSystemDDLWorkerQueue>(context, system_database, "distributed_ddl_queue"); attach<StorageSystemDDLWorkerQueue>(context, system_database, "distributed_ddl_queue");
attach<StorageSystemDistributionQueue>(context, system_database, "distribution_queue"); attach<StorageSystemDistributionQueue>(context, system_database, "distribution_queue");
attach<StorageSystemDictionaries>(context, system_database, "dictionaries"); attach<StorageSystemDictionaries>(context, system_database, "dictionaries");
attach<StorageSystemModels>(context, system_database, "models");
attach<StorageSystemClusters>(context, system_database, "clusters"); attach<StorageSystemClusters>(context, system_database, "clusters");
attach<StorageSystemGraphite>(context, system_database, "graphite_retentions"); attach<StorageSystemGraphite>(context, system_database, "graphite_retentions");
attach<StorageSystemMacros>(context, system_database, "macros"); attach<StorageSystemMacros>(context, system_database, "macros");

View File

@ -44,6 +44,8 @@ def testConstantFeatures(ch_cluster):
if instance.is_built_with_memory_sanitizer(): if instance.is_built_with_memory_sanitizer():
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries") pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
result = instance.query("system reload models")
result = instance.query( result = instance.query(
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);" "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
) )
@ -55,6 +57,8 @@ def testNonConstantFeatures(ch_cluster):
if instance.is_built_with_memory_sanitizer(): if instance.is_built_with_memory_sanitizer():
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries") pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
result = instance.query("system reload models")
instance.query("DROP TABLE IF EXISTS T;") instance.query("DROP TABLE IF EXISTS T;")
instance.query( instance.query(
"CREATE TABLE T(ID UInt32, F1 Float32, F2 Float32, F3 UInt32, F4 UInt32, F5 UInt32, F6 UInt32, F7 UInt32, F8 UInt32, F9 Float32, F10 Float32, F11 Float32) ENGINE MergeTree ORDER BY ID;" "CREATE TABLE T(ID UInt32, F1 Float32, F2 Float32, F3 UInt32, F4 UInt32, F5 UInt32, F6 UInt32, F7 UInt32, F8 UInt32, F9 Float32, F10 Float32, F11 Float32) ENGINE MergeTree ORDER BY ID;"
@ -74,6 +78,8 @@ def testModelPathIsNotAConstString(ch_cluster):
if instance.is_built_with_memory_sanitizer(): if instance.is_built_with_memory_sanitizer():
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries") pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
result = instance.query("system reload models")
err = instance.query_and_get_error( err = instance.query_and_get_error(
"select catboostEvaluate(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);" "select catboostEvaluate(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
) )
@ -98,6 +104,8 @@ def testWrongNumberOfFeatureArguments(ch_cluster):
if instance.is_built_with_memory_sanitizer(): if instance.is_built_with_memory_sanitizer():
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries") pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
result = instance.query("system reload models")
err = instance.query_and_get_error( err = instance.query_and_get_error(
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin');" "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin');"
) )
@ -116,6 +124,8 @@ def testFloatFeatureMustBeNumeric(ch_cluster):
if instance.is_built_with_memory_sanitizer(): if instance.is_built_with_memory_sanitizer():
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries") pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
result = instance.query("system reload models")
err = instance.query_and_get_error( err = instance.query_and_get_error(
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 'a', 3, 4, 5, 6, 7, 8, 9, 10, 11);" "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 'a', 3, 4, 5, 6, 7, 8, 9, 10, 11);"
) )
@ -126,6 +136,8 @@ def testCategoricalFeatureMustBeNumericOrString(ch_cluster):
if instance.is_built_with_memory_sanitizer(): if instance.is_built_with_memory_sanitizer():
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries") pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
result = instance.query("system reload models")
err = instance.query_and_get_error( err = instance.query_and_get_error(
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, tuple(8), 9, 10, 11);" "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, tuple(8), 9, 10, 11);"
) )
@ -136,6 +148,8 @@ def testOnLowCardinalityFeatures(ch_cluster):
if instance.is_built_with_memory_sanitizer(): if instance.is_built_with_memory_sanitizer():
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries") pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
result = instance.query("system reload models")
# same but on domain-compressed data # same but on domain-compressed data
result = instance.query( result = instance.query(
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', toLowCardinality(1.0), toLowCardinality(2.0), toLowCardinality(3), toLowCardinality(4), toLowCardinality(5), toLowCardinality(6), toLowCardinality(7), toLowCardinality(8), toLowCardinality(9), toLowCardinality(10), toLowCardinality(11));" "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', toLowCardinality(1.0), toLowCardinality(2.0), toLowCardinality(3), toLowCardinality(4), toLowCardinality(5), toLowCardinality(6), toLowCardinality(7), toLowCardinality(8), toLowCardinality(9), toLowCardinality(10), toLowCardinality(11));"
@ -148,6 +162,8 @@ def testOnNullableFeatures(ch_cluster):
if instance.is_built_with_memory_sanitizer(): if instance.is_built_with_memory_sanitizer():
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries") pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
result = instance.query("system reload models")
result = instance.query( result = instance.query(
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', toNullable(1.0), toNullable(2.0), toNullable(3), toNullable(4), toNullable(5), toNullable(6), toNullable(7), toNullable(8), toNullable(9), toNullable(10), toNullable(11));" "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', toNullable(1.0), toNullable(2.0), toNullable(3), toNullable(4), toNullable(5), toNullable(6), toNullable(7), toNullable(8), toNullable(9), toNullable(10), toNullable(11));"
) )
@ -165,6 +181,8 @@ def testInvalidLibraryPath(ch_cluster):
if instance.is_built_with_memory_sanitizer(): if instance.is_built_with_memory_sanitizer():
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries") pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
result = instance.query("system reload models")
# temporarily move library elsewhere # temporarily move library elsewhere
instance.exec_in_container( instance.exec_in_container(
[ [
@ -196,6 +214,8 @@ def testInvalidModelPath(ch_cluster):
if instance.is_built_with_memory_sanitizer(): if instance.is_built_with_memory_sanitizer():
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries") pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
result = instance.query("system reload models")
err = instance.query_and_get_error( err = instance.query_and_get_error(
"select catboostEvaluate('', 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);" "select catboostEvaluate('', 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
) )
@ -211,6 +231,8 @@ def testRecoveryAfterCrash(ch_cluster):
if instance.is_built_with_memory_sanitizer(): if instance.is_built_with_memory_sanitizer():
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries") pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
result = instance.query("system reload models")
result = instance.query( result = instance.query(
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);" "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
) )
@ -235,6 +257,8 @@ def testAmazonModelSingleRow(ch_cluster):
if instance.is_built_with_memory_sanitizer(): if instance.is_built_with_memory_sanitizer():
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries") pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
result = instance.query("system reload models")
result = instance.query( result = instance.query(
"select catboostEvaluate('/etc/clickhouse-server/model/amazon_model.bin', 1, 2, 3, 4, 5, 6, 7, 8, 9);" "select catboostEvaluate('/etc/clickhouse-server/model/amazon_model.bin', 1, 2, 3, 4, 5, 6, 7, 8, 9);"
) )
@ -246,6 +270,8 @@ def testAmazonModelManyRows(ch_cluster):
if instance.is_built_with_memory_sanitizer(): if instance.is_built_with_memory_sanitizer():
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries") pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
result = instance.query("system reload models")
result = instance.query("drop table if exists amazon") result = instance.query("drop table if exists amazon")
result = instance.query( result = instance.query(
@ -272,6 +298,8 @@ def testModelUpdate(ch_cluster):
if instance.is_built_with_memory_sanitizer(): if instance.is_built_with_memory_sanitizer():
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries") pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
result = instance.query("system reload models")
query = "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);" query = "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
result = instance.query(query) result = instance.query(query)
@ -294,13 +322,18 @@ def testModelUpdate(ch_cluster):
] ]
) )
# since the amazon model has a different number of features than the simple model, we should get an error # unload simple model
err = instance.query_and_get_error(query) result = instance.query(
assert ( "system reload model '/etc/clickhouse-server/model/simple_model.bin'"
"Number of columns is different with number of features: columns size 11 float features size 0 + cat features size 9"
in err
) )
# load the simple-model-camouflaged amazon model
result = instance.query(
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1, 2, 3, 4, 5, 6, 7, 8, 9);"
)
expected = "0.7774665009089274\n"
assert result == expected
# restore # restore
instance.exec_in_container( instance.exec_in_container(
[ [
@ -316,3 +349,54 @@ def testModelUpdate(ch_cluster):
"mv /etc/clickhouse-server/model/simple_model.bin.bak /etc/clickhouse-server/model/simple_model.bin", "mv /etc/clickhouse-server/model/simple_model.bin.bak /etc/clickhouse-server/model/simple_model.bin",
] ]
) )
def testSystemModelsAndModelRefresh(ch_cluster):
if instance.is_built_with_memory_sanitizer():
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
result = instance.query("system reload models")
# check model system view
result = instance.query("select * from system.models")
expected = ""
assert result == expected
# load simple model
result = instance.query(
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
)
expected = "-1.930268705869267\n"
assert result == expected
# check model system view with one model loaded
result = instance.query("select * from system.models")
assert result.count("\n") == 1
expected = "/etc/clickhouse-server/model/simple_model.bin"
assert expected in result
# load amazon model
result = instance.query(
"select catboostEvaluate('/etc/clickhouse-server/model/amazon_model.bin', 1, 2, 3, 4, 5, 6, 7, 8, 9);"
)
expected = "0.7774665009089274\n"
assert result == expected
# check model system view with one model loaded
result = instance.query("select * from system.models")
assert result.count("\n") == 2
expected = "/etc/clickhouse-server/model/simple_model.bin"
assert expected in result
expected = "/etc/clickhouse-server/model/amazon_model.bin"
assert expected in result
# unload simple model
result = instance.query(
"system reload model '/etc/clickhouse-server/model/simple_model.bin'"
)
# check model system view, it should not display the removed model
result = instance.query("select * from system.models")
assert result.count("\n") == 1
expected = "/etc/clickhouse-server/model/amazon_model.bin"
assert expected in result