From fac1be9700956c03582ed759b4978b40988e2da9 Mon Sep 17 00:00:00 2001 From: Robert Schulze Date: Sun, 11 Sep 2022 15:50:36 +0000 Subject: [PATCH] chore: restore SYSTEM RELOAD MODEL(S) and monitoring view SYSTEM.MODELS - This commit restores statements "SYSTEM RELOAD MODEL(S)" which provide a mechanism to update a model explicitly. It also saves potentially unnecessary reloads of a model from disk after its initial load. To keep the complexity low, the semantics of "SYSTEM RELOAD MODEL(S)" was changed from eager to lazy. This means that both statements previously immediately reloaded the specified/all models, whereas now the statements only trigger an unload and the first call to catboostEvaluate() does the actual load. - Monitoring view SYSTEM.MODELS is also restored but with some obsolete fields removed. The view was not documented in the past and for now it remains undocumented. The commit is thus not considered a breach of ClickHouse's public interface. --- docs/en/sql-reference/statements/system.md | 22 +++++ docs/ru/sql-reference/statements/system.md | 22 +++++ .../library-bridge/CatBoostLibraryHandler.cpp | 15 ++- .../library-bridge/CatBoostLibraryHandler.h | 11 ++- .../CatBoostLibraryHandlerFactory.cpp | 62 ++++++++---- .../CatBoostLibraryHandlerFactory.h | 11 ++- .../library-bridge/LibraryBridgeHandlers.cpp | 66 +++++++++---- .../library-bridge/LibraryBridgeHandlers.h | 27 +++--- .../CatBoostLibraryBridgeHelper.cpp | 86 ++++++++++++++++- .../CatBoostLibraryBridgeHelper.h | 12 ++- src/Common/ExternalModelInfo.h | 20 ++++ src/Interpreters/InterpreterSystemQuery.cpp | 10 +- src/Storages/System/StorageSystemModels.cpp | 38 ++++++++ src/Storages/System/StorageSystemModels.h | 25 +++++ src/Storages/System/attachSystemTables.cpp | 2 + .../test_catboost_evaluate/test.py | 94 ++++++++++++++++++- 16 files changed, 459 insertions(+), 64 deletions(-) create mode 100644 src/Common/ExternalModelInfo.h create mode 100644 src/Storages/System/StorageSystemModels.cpp create 
mode 100644 src/Storages/System/StorageSystemModels.h diff --git a/docs/en/sql-reference/statements/system.md b/docs/en/sql-reference/statements/system.md index 67eb94f3606..ea5bcf7197b 100644 --- a/docs/en/sql-reference/statements/system.md +++ b/docs/en/sql-reference/statements/system.md @@ -11,6 +11,8 @@ The list of available `SYSTEM` statements: - [RELOAD EMBEDDED DICTIONARIES](#query_language-system-reload-emdedded-dictionaries) - [RELOAD DICTIONARIES](#query_language-system-reload-dictionaries) - [RELOAD DICTIONARY](#query_language-system-reload-dictionary) +- [RELOAD MODELS](#query_language-system-reload-models) +- [RELOAD MODEL](#query_language-system-reload-model) - [RELOAD FUNCTIONS](#query_language-system-reload-functions) - [RELOAD FUNCTION](#query_language-system-reload-functions) - [DROP DNS CACHE](#query_language-system-drop-dns-cache) @@ -65,6 +67,26 @@ The status of the dictionary can be checked by querying the `system.dictionaries SELECT name, status FROM system.dictionaries; ``` +## RELOAD MODELS + +Reloads all [CatBoost](../../guides/developer/apply-catboost-model.md) models. + +**Syntax** + +```sql +SYSTEM RELOAD MODELS [ON CLUSTER cluster_name] +``` + +## RELOAD MODEL + +Completely reloads a CatBoost model `model_path`. + +**Syntax** + +```sql +SYSTEM RELOAD MODEL [ON CLUSTER cluster_name] <model_path> +``` + ## RELOAD FUNCTIONS Reloads all registered [executable user defined functions](../functions/index.md#executable-user-defined-functions) or one of them from a configuration file. 
diff --git a/docs/ru/sql-reference/statements/system.md b/docs/ru/sql-reference/statements/system.md index 01c9339e52d..5f3f1ad7d3c 100644 --- a/docs/ru/sql-reference/statements/system.md +++ b/docs/ru/sql-reference/statements/system.md @@ -9,6 +9,8 @@ sidebar_label: SYSTEM - [RELOAD EMBEDDED DICTIONARIES](#query_language-system-reload-emdedded-dictionaries) - [RELOAD DICTIONARIES](#query_language-system-reload-dictionaries) - [RELOAD DICTIONARY](#query_language-system-reload-dictionary) +- [RELOAD MODELS](#query_language-system-reload-models) +- [RELOAD MODEL](#query_language-system-reload-model) - [RELOAD FUNCTIONS](#query_language-system-reload-functions) - [RELOAD FUNCTION](#query_language-system-reload-functions) - [DROP DNS CACHE](#query_language-system-drop-dns-cache) @@ -62,6 +64,26 @@ sidebar_label: SYSTEM SELECT name, status FROM system.dictionaries; ``` +## RELOAD MODELS {#query_language-system-reload-models} + +Перегружает все модели [CatBoost](../../guides/apply-catboost-model.md#applying-catboost-model-in-clickhouse). + +**Синтаксис** + +```sql +SYSTEM RELOAD MODELS +``` + +## RELOAD MODEL {#query_language-system-reload-model} + +Полностью перегружает модель [CatBoost](../../guides/apply-catboost-model.md#applying-catboost-model-in-clickhouse) `model_path`. + +**Синтаксис** + +```sql +SYSTEM RELOAD MODEL <model_path> +``` + ## RELOAD FUNCTIONS {#query_language-system-reload-functions} Перезагружает все зарегистрированные [исполняемые пользовательские функции](../functions/index.md#executable-user-defined-functions) или одну из них из файла конфигурации. 
diff --git a/programs/library-bridge/CatBoostLibraryHandler.cpp b/programs/library-bridge/CatBoostLibraryHandler.cpp index 18e56836f96..2c3ed583463 100644 --- a/programs/library-bridge/CatBoostLibraryHandler.cpp +++ b/programs/library-bridge/CatBoostLibraryHandler.cpp @@ -34,7 +34,8 @@ CatBoostLibraryHandler::APIHolder::APIHolder(SharedLibrary & lib) CatBoostLibraryHandler::CatBoostLibraryHandler( const std::string & library_path, const std::string & model_path) - : library(std::make_shared(library_path)) + : loading_start_time(std::chrono::system_clock::now()) + , library(std::make_shared(library_path)) , api(*library) { model_calcer_handle = api.ModelCalcerCreate(); @@ -51,6 +52,8 @@ CatBoostLibraryHandler::CatBoostLibraryHandler( tree_count = 1; if (api.GetDimensionsCount) tree_count = api.GetDimensionsCount(model_calcer_handle); + + loading_duration = std::chrono::duration_cast(std::chrono::system_clock::now() - loading_start_time); } CatBoostLibraryHandler::~CatBoostLibraryHandler() @@ -58,6 +61,16 @@ CatBoostLibraryHandler::~CatBoostLibraryHandler() api.ModelCalcerDelete(model_calcer_handle); } +std::chrono::system_clock::time_point CatBoostLibraryHandler::getLoadingStartTime() const +{ + return loading_start_time; +} + +std::chrono::milliseconds CatBoostLibraryHandler::getLoadingDuration() const +{ + return loading_duration; +} + namespace { diff --git a/programs/library-bridge/CatBoostLibraryHandler.h b/programs/library-bridge/CatBoostLibraryHandler.h index 306b16c4805..e0ff1d70250 100644 --- a/programs/library-bridge/CatBoostLibraryHandler.h +++ b/programs/library-bridge/CatBoostLibraryHandler.h @@ -10,6 +10,7 @@ #include #include +#include #include namespace DB @@ -42,16 +43,22 @@ public: }; CatBoostLibraryHandler( - const std::string & library_path, - const std::string & model_path); + const String & library_path, + const String & model_path); ~CatBoostLibraryHandler(); + std::chrono::system_clock::time_point getLoadingStartTime() const; + 
std::chrono::milliseconds getLoadingDuration() const; + size_t getTreeCount() const; ColumnPtr evaluate(const ColumnRawPtrs & columns) const; private: + std::chrono::system_clock::time_point loading_start_time; + std::chrono::milliseconds loading_duration; + const SharedLibraryPtr library; const APIHolder api; diff --git a/programs/library-bridge/CatBoostLibraryHandlerFactory.cpp b/programs/library-bridge/CatBoostLibraryHandlerFactory.cpp index 5dde7f84a95..841b4398f73 100644 --- a/programs/library-bridge/CatBoostLibraryHandlerFactory.cpp +++ b/programs/library-bridge/CatBoostLibraryHandlerFactory.cpp @@ -2,6 +2,7 @@ #include + namespace DB { @@ -11,39 +12,64 @@ CatBoostLibraryHandlerFactory & CatBoostLibraryHandlerFactory::instance() return instance; } -CatBoostLibraryHandlerPtr CatBoostLibraryHandlerFactory::get(const String & model_path) +CatBoostLibraryHandlerPtr CatBoostLibraryHandlerFactory::getOrCreateModel(const String & model_path, const String & library_path, bool create_if_not_found) { std::lock_guard lock(mutex); - if (auto handler = library_handlers.find(model_path); handler != library_handlers.end()) + auto handler = library_handlers.find(model_path); + bool found = (handler != library_handlers.end()); + + if (found) return handler->second; - return nullptr; -} - -void CatBoostLibraryHandlerFactory::create(const String & library_path, const String & model_path) -{ - std::lock_guard lock(mutex); - - if (library_handlers.contains(model_path)) + else { - LOG_DEBUG(&Poco::Logger::get("CatBoostLibraryHandlerFactory"), "Cannot load catboost library handler for model path {} because it exists already", model_path); - return; + if (create_if_not_found) + { + auto new_handler = std::make_shared(library_path, model_path); + library_handlers.emplace(std::make_pair(model_path, new_handler)); + LOG_DEBUG(&Poco::Logger::get("CatBoostLibraryHandlerFactory"), "Loaded catboost library handler for model path '{}'", model_path); + return new_handler; + } + return 
nullptr; } - - library_handlers.emplace(std::make_pair(model_path, std::make_shared(library_path, model_path))); - LOG_DEBUG(&Poco::Logger::get("CatBoostLibraryHandlerFactory"), "Loaded catboost library handler for model path {}.", model_path); } -void CatBoostLibraryHandlerFactory::remove(const String & model_path) +void CatBoostLibraryHandlerFactory::removeModel(const String & model_path) { std::lock_guard lock(mutex); + bool deleted = library_handlers.erase(model_path); if (!deleted) { - LOG_DEBUG(&Poco::Logger::get("CatBoostLibraryHandlerFactory"), "Cannot unload catboost library handler for model path: {}", model_path); + LOG_DEBUG(&Poco::Logger::get("CatBoostLibraryHandlerFactory"), "Cannot unload catboost library handler for model path '{}'", model_path); return; } - LOG_DEBUG(&Poco::Logger::get("CatBoostLibraryHandlerFactory"), "Unloaded catboost library handler for model path: {}", model_path); + LOG_DEBUG(&Poco::Logger::get("CatBoostLibraryHandlerFactory"), "Unloaded catboost library handler for model path '{}'", model_path); +} + +void CatBoostLibraryHandlerFactory::removeAllModels() +{ + std::lock_guard lock(mutex); + library_handlers.clear(); + LOG_DEBUG(&Poco::Logger::get("CatBoostLibraryHandlerFactory"), "Unloaded all catboost library handlers"); +} + +ExternalModelInfos CatBoostLibraryHandlerFactory::getModelInfos() +{ + std::lock_guard lock(mutex); + + ExternalModelInfos result; + + for (const auto & handler : library_handlers) + result.push_back({ + .model_path = handler.first, + .model_type = "catboost", + .loading_start_time = handler.second->getLoadingStartTime(), + .loading_duration = handler.second->getLoadingDuration() + }); + + return result; + } } diff --git a/programs/library-bridge/CatBoostLibraryHandlerFactory.h b/programs/library-bridge/CatBoostLibraryHandlerFactory.h index f17d19a3f06..4984d1713a8 100644 --- a/programs/library-bridge/CatBoostLibraryHandlerFactory.h +++ b/programs/library-bridge/CatBoostLibraryHandlerFactory.h @@ -3,7 
+3,9 @@ #include "CatBoostLibraryHandler.h" #include +#include +#include #include #include @@ -16,14 +18,15 @@ class CatBoostLibraryHandlerFactory final : private boost::noncopyable public: static CatBoostLibraryHandlerFactory & instance(); - CatBoostLibraryHandlerPtr get(const String & model_path); + CatBoostLibraryHandlerPtr getOrCreateModel(const String & model_path, const String & library_path, bool create_if_not_found); - void create(const String & library_path, const String & model_path); + void removeModel(const String & model_path); + void removeAllModels(); - void remove(const String & model_path); + ExternalModelInfos getModelInfos(); private: - /// map: model path -> shared library handler + /// map: model path --> catboost library handler std::unordered_map library_handlers TSA_GUARDED_BY(mutex); std::mutex mutex; }; diff --git a/programs/library-bridge/LibraryBridgeHandlers.cpp b/programs/library-bridge/LibraryBridgeHandlers.cpp index 999cf9d4582..ccfc4c2b32b 100644 --- a/programs/library-bridge/LibraryBridgeHandlers.cpp +++ b/programs/library-bridge/LibraryBridgeHandlers.cpp @@ -26,6 +26,7 @@ #include #include #include +#include namespace DB @@ -422,8 +423,6 @@ CatBoostLibraryBridgeRequestHandler::CatBoostLibraryBridgeRequestHandler( void CatBoostLibraryBridgeRequestHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) { - std::lock_guard lock(mutex); - LOG_TRACE(log, "Request URI: {}", request.getURI()); HTMLForm params(getContext()->getSettingsRef(), request); @@ -460,7 +459,51 @@ void CatBoostLibraryBridgeRequestHandler::handleRequest(HTTPServerRequest & requ try { - if (method == "catboost_GetTreeCount") + if (method == "catboost_list") + { + ExternalModelInfos model_infos = CatBoostLibraryHandlerFactory::instance().getModelInfos(); + + writeIntBinary(static_cast(model_infos.size()), out); + + for (const auto & info : model_infos) + { + writeStringBinary(info.model_path, out); + writeStringBinary(info.model_type, out); 
+ + UInt64 t = std::chrono::system_clock::to_time_t(info.loading_start_time); + writeIntBinary(t, out); + + t = info.loading_duration.count(); + writeIntBinary(t, out); + + } + } + else if (method == "catboost_removeModel") + { + auto & read_buf = request.getStream(); + params.read(read_buf); + + if (!params.has("model_path")) + { + processError(response, "No 'model_path' in request URL"); + return; + } + + const String & model_path = params.get("model_path"); + + CatBoostLibraryHandlerFactory::instance().removeModel(model_path); + + String res = "1"; + writeStringBinary(res, out); + } + else if (method == "catboost_removeAllModels") + { + CatBoostLibraryHandlerFactory::instance().removeAllModels(); + + String res = "1"; + writeStringBinary(res, out); + } + else if (method == "catboost_GetTreeCount") { auto & read_buf = request.getStream(); params.read(read_buf); @@ -481,18 +524,7 @@ void CatBoostLibraryBridgeRequestHandler::handleRequest(HTTPServerRequest & requ const String & model_path = params.get("model_path"); - CatBoostLibraryHandlerFactory::instance().remove(model_path); - - CatBoostLibraryHandlerFactory::instance().create(library_path, model_path); - - auto catboost_handler = CatBoostLibraryHandlerFactory::instance().get(model_path); - - if (!catboost_handler) - { - processError(response, "CatBoost library is not loaded for model " + model_path); - return; - } - + auto catboost_handler = CatBoostLibraryHandlerFactory::instance().getOrCreateModel(model_path, library_path, /*create_if_not_found*/ true); size_t tree_count = catboost_handler->getTreeCount(); writeIntBinary(tree_count, out); } @@ -526,11 +558,11 @@ void CatBoostLibraryBridgeRequestHandler::handleRequest(HTTPServerRequest & requ for (const auto & p : col_ptrs) col_raw_ptrs.push_back(&*p); - auto catboost_handler = CatBoostLibraryHandlerFactory::instance().get(model_path); + auto catboost_handler = CatBoostLibraryHandlerFactory::instance().getOrCreateModel(model_path, "DummyLibraryPath", 
/*create_if_not_found*/ false); if (!catboost_handler) { - processError(response, "CatBoost library is not loaded for model" + model_path); + processError(response, "CatBoost library is not loaded for model '" + model_path + "'. Please try again."); return; } diff --git a/programs/library-bridge/LibraryBridgeHandlers.h b/programs/library-bridge/LibraryBridgeHandlers.h index 5e171b882e2..16815e84723 100644 --- a/programs/library-bridge/LibraryBridgeHandlers.h +++ b/programs/library-bridge/LibraryBridgeHandlers.h @@ -3,7 +3,6 @@ #include #include #include -#include namespace DB @@ -46,16 +45,21 @@ private: /// Handler for requests to catboost library. The call protocol is as follows: -/// (1) Send a "catboost_GetTreeCount" request from the server to the bridge, containing a library path (e.g /home/user/libcatboost.so) and -/// a model path (e.g. /home/user/model.bin). Rirst, this unloads the catboost library handler associated to the model path (if it was -/// loaded), then loads the catboost library handler associated to the model path, then executes GetTreeCount() on the library handler -/// and finally sends the result back to the server. -/// Step (1) is called once by the server from FunctionCatBoostEvaluate::getReturnTypeImpl(). The library path handler is unloaded in -/// the beginning because it contains state which may no longer be valid if the user runs catboost("/path/to/model.bin", ...) more than -/// once and if "model.bin" was updated in between. -/// (2) Send "catboost_Evaluate" from the server to the bridge, containing the model path and the features to run the interference on. -/// Step (2) is called multiple times (once per chunk) by the server from function FunctionCatBoostEvaluate::executeImpl(). The library -/// handler for the given model path is expected to be already loaded by Step (1). +/// (1) Send a "catboost_GetTreeCount" request from the server to the bridge. 
It contains a library path (e.g. /home/user/libcatboost.so) and +/// a model path (e.g. /home/user/model.bin). This loads the catboost library handler associated with the model path, then executes +/// GetTreeCount() on the library handler and sends the result back to the server. +/// (2) Send "catboost_Evaluate" from the server to the bridge. It contains a model path and the features to run the inference on. Step +/// (2) is called multiple times (once per chunk) by the server. +/// +/// We would ideally like to have steps (1) and (2) in one atomic handler but can't because the evaluation on the server side is divided +/// into two dependent phases: FunctionCatBoostEvaluate::getReturnTypeImpl() and ::executeImpl(). So the model may in principle be unloaded +/// from the library-bridge between steps (1) and (2). Step (2) checks if that is the case and fails gracefully. This is okay because that +/// situation is considered exceptional and rare. +/// +/// An update of a model is performed by unloading it. The first call to "catboost_GetTreeCount" brings it into memory again. +/// +/// Further handlers are provided for unloading a specific model, for unloading all models or for retrieving information about the loaded +/// models for display in a system view. 
class CatBoostLibraryBridgeRequestHandler : public HTTPRequestHandler, WithContext { public: @@ -64,7 +68,6 @@ public: void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) override; private: - std::mutex mutex; const size_t keep_alive_timeout; Poco::Logger * log; }; diff --git a/src/BridgeHelper/CatBoostLibraryBridgeHelper.cpp b/src/BridgeHelper/CatBoostLibraryBridgeHelper.cpp index e3693ed50e0..28ed12ca6a7 100644 --- a/src/BridgeHelper/CatBoostLibraryBridgeHelper.cpp +++ b/src/BridgeHelper/CatBoostLibraryBridgeHelper.cpp @@ -10,6 +10,8 @@ #include #include +#include + namespace DB { @@ -28,6 +30,16 @@ CatBoostLibraryBridgeHelper::CatBoostLibraryBridgeHelper( { } +CatBoostLibraryBridgeHelper::CatBoostLibraryBridgeHelper(ContextPtr context_, std::string_view model_path_) + : CatBoostLibraryBridgeHelper(context_, "", model_path_) +{ +} + +CatBoostLibraryBridgeHelper::CatBoostLibraryBridgeHelper(ContextPtr context_) + : CatBoostLibraryBridgeHelper(context_, "", "") +{ +} + Poco::URI CatBoostLibraryBridgeHelper::getPingURI() const { auto uri = createBaseURI(); @@ -71,6 +83,74 @@ bool CatBoostLibraryBridgeHelper::bridgeHandShake() return true; } +ExternalModelInfos CatBoostLibraryBridgeHelper::listModels() +{ + startBridgeSync(); + + ReadWriteBufferFromHTTP buf( + createRequestURI(CATBOOST_LIST_METHOD), + Poco::Net::HTTPRequest::HTTP_POST, + [](std::ostream &) {}, + http_timeouts, credentials); + + ExternalModelInfos result; + + UInt64 num_rows; + readIntBinary(num_rows, buf); + + for (UInt64 i = 0; i < num_rows; ++i) + { + ExternalModelInfo info; + + readStringBinary(info.model_path, buf); + readStringBinary(info.model_type, buf); + + UInt64 t; + readIntBinary(t, buf); + info.loading_start_time = std::chrono::system_clock::from_time_t(t); + + readIntBinary(t, buf); + info.loading_duration = std::chrono::milliseconds(t); + + result.push_back(info); + } + + return result; +} + +void CatBoostLibraryBridgeHelper::removeModel() +{ + 
startBridgeSync(); + + ReadWriteBufferFromHTTP buf( + createRequestURI(CATBOOST_REMOVEMODEL_METHOD), + Poco::Net::HTTPRequest::HTTP_POST, + [this](std::ostream & os) + { + os << "model_path=" << escapeForFileName(model_path); + }, + http_timeouts, credentials); + + String result; + readStringBinary(result, buf); + assert(result == "1"); +} + +void CatBoostLibraryBridgeHelper::removeAllModels() +{ + startBridgeSync(); + + ReadWriteBufferFromHTTP buf( + createRequestURI(CATBOOST_REMOVEALLMODELS_METHOD), + Poco::Net::HTTPRequest::HTTP_POST, + [](std::ostream &){}, + http_timeouts, credentials); + + String result; + readStringBinary(result, buf); + assert(result == "1"); +} + size_t CatBoostLibraryBridgeHelper::getTreeCount() { startBridgeSync(); @@ -85,9 +165,9 @@ size_t CatBoostLibraryBridgeHelper::getTreeCount() }, http_timeouts, credentials); - size_t res; - readIntBinary(res, buf); - return res; + size_t result; + readIntBinary(result, buf); + return result; } ColumnPtr CatBoostLibraryBridgeHelper::evaluate(const ColumnsWithTypeAndName & columns) diff --git a/src/BridgeHelper/CatBoostLibraryBridgeHelper.h b/src/BridgeHelper/CatBoostLibraryBridgeHelper.h index 7a88f4ca368..b19d783ed5b 100644 --- a/src/BridgeHelper/CatBoostLibraryBridgeHelper.h +++ b/src/BridgeHelper/CatBoostLibraryBridgeHelper.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -17,9 +18,15 @@ public: static constexpr inline auto MAIN_HANDLER = "/catboost_request"; CatBoostLibraryBridgeHelper(ContextPtr context_, std::string_view library_path_, std::string_view model_path_); + CatBoostLibraryBridgeHelper(ContextPtr context_, std::string_view model_path_); + explicit CatBoostLibraryBridgeHelper(ContextPtr context_); + + ExternalModelInfos listModels(); + + void removeModel(); + void removeAllModels(); size_t getTreeCount(); - ColumnPtr evaluate(const ColumnsWithTypeAndName & columns); protected: @@ -30,6 +37,9 @@ protected: bool bridgeHandShake() override; private: + 
static constexpr inline auto CATBOOST_LIST_METHOD = "catboost_list"; + static constexpr inline auto CATBOOST_REMOVEMODEL_METHOD = "catboost_removeModel"; + static constexpr inline auto CATBOOST_REMOVEALLMODELS_METHOD = "catboost_removeAllModels"; static constexpr inline auto CATBOOST_GETTREECOUNT_METHOD = "catboost_GetTreeCount"; static constexpr inline auto CATBOOST_LIB_EVALUATE_METHOD = "catboost_libEvaluate"; diff --git a/src/Common/ExternalModelInfo.h b/src/Common/ExternalModelInfo.h new file mode 100644 index 00000000000..378e4984af6 --- /dev/null +++ b/src/Common/ExternalModelInfo.h @@ -0,0 +1,20 @@ +#pragma once + +#include +#include + +namespace DB +{ + +/// Details about external machine learning model, used by clickhouse-server and clickhouse-library-bridge +struct ExternalModelInfo +{ + String model_path; + String model_type; + std::chrono::system_clock::time_point loading_start_time; /// serialized as std::time_t + std::chrono::milliseconds loading_duration; /// serialized as UInt64 +}; + +using ExternalModelInfos = std::vector; + +} diff --git a/src/Interpreters/InterpreterSystemQuery.cpp b/src/Interpreters/InterpreterSystemQuery.cpp index 9863130fce3..32866e63260 100644 --- a/src/Interpreters/InterpreterSystemQuery.cpp +++ b/src/Interpreters/InterpreterSystemQuery.cpp @@ -35,6 +35,7 @@ #include #include #include +#include #include #include #include @@ -373,10 +374,17 @@ BlockIO InterpreterSystemQuery::execute() break; } case Type::RELOAD_MODEL: + { + getContext()->checkAccess(AccessType::SYSTEM_RELOAD_MODEL); + auto bridge_helper = std::make_unique(getContext(), query.target_model); + bridge_helper->removeModel(); + break; + } case Type::RELOAD_MODELS: { getContext()->checkAccess(AccessType::SYSTEM_RELOAD_MODEL); - /// SYSTEM RELOAD MODEL(S) is no longer supported. For compat reasons, it is not completely removed but retained as no-op. 
+ auto bridge_helper = std::make_unique(getContext()); + bridge_helper->removeAllModels(); break; } case Type::RELOAD_FUNCTION: diff --git a/src/Storages/System/StorageSystemModels.cpp b/src/Storages/System/StorageSystemModels.cpp new file mode 100644 index 00000000000..d06f97a3f54 --- /dev/null +++ b/src/Storages/System/StorageSystemModels.cpp @@ -0,0 +1,38 @@ +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace DB +{ + +NamesAndTypesList StorageSystemModels::getNamesAndTypes() +{ + return { + { "model_path", std::make_shared() }, + { "type", std::make_shared() }, + { "loading_start_time", std::make_shared() }, + { "loading_duration", std::make_shared() }, + }; +} + +void StorageSystemModels::fillData(MutableColumns & res_columns, ContextPtr context, const SelectQueryInfo &) const +{ + auto bridge_helper = std::make_unique(context); + ExternalModelInfos infos = bridge_helper->listModels(); + + for (const auto & info : infos) + { + res_columns[0]->insert(info.model_path); + res_columns[1]->insert(info.model_type); + res_columns[2]->insert(static_cast(std::chrono::system_clock::to_time_t(info.loading_start_time))); + res_columns[3]->insert(std::chrono::duration_cast>(info.loading_duration).count()); + } +} + +} diff --git a/src/Storages/System/StorageSystemModels.h b/src/Storages/System/StorageSystemModels.h new file mode 100644 index 00000000000..dfb6ad3de5a --- /dev/null +++ b/src/Storages/System/StorageSystemModels.h @@ -0,0 +1,25 @@ +#pragma once + +#include + + +namespace DB +{ + +class Context; + + +class StorageSystemModels final : public IStorageSystemOneBlock +{ +public: + std::string getName() const override { return "SystemModels"; } + + static NamesAndTypesList getNamesAndTypes(); + +protected: + using IStorageSystemOneBlock::IStorageSystemOneBlock; + + void fillData(MutableColumns & res_columns, ContextPtr context, const SelectQueryInfo & query_info) const override; +}; + +} diff --git 
a/src/Storages/System/attachSystemTables.cpp b/src/Storages/System/attachSystemTables.cpp index fa70751ee19..ab1ffdf209a 100644 --- a/src/Storages/System/attachSystemTables.cpp +++ b/src/Storages/System/attachSystemTables.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -163,6 +164,7 @@ void attachSystemTablesServer(ContextPtr context, IDatabase & system_database, b attach(context, system_database, "distributed_ddl_queue"); attach(context, system_database, "distribution_queue"); attach(context, system_database, "dictionaries"); + attach(context, system_database, "models"); attach(context, system_database, "clusters"); attach(context, system_database, "graphite_retentions"); attach(context, system_database, "macros"); diff --git a/tests/integration/test_catboost_evaluate/test.py b/tests/integration/test_catboost_evaluate/test.py index 897252374da..a0915977ab6 100644 --- a/tests/integration/test_catboost_evaluate/test.py +++ b/tests/integration/test_catboost_evaluate/test.py @@ -44,6 +44,8 @@ def testConstantFeatures(ch_cluster): if instance.is_built_with_memory_sanitizer(): pytest.skip("Memory Sanitizer cannot work with third-party shared libraries") + result = instance.query("system reload models") + result = instance.query( "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);" ) @@ -55,6 +57,8 @@ def testNonConstantFeatures(ch_cluster): if instance.is_built_with_memory_sanitizer(): pytest.skip("Memory Sanitizer cannot work with third-party shared libraries") + result = instance.query("system reload models") + instance.query("DROP TABLE IF EXISTS T;") instance.query( "CREATE TABLE T(ID UInt32, F1 Float32, F2 Float32, F3 UInt32, F4 UInt32, F5 UInt32, F6 UInt32, F7 UInt32, F8 UInt32, F9 Float32, F10 Float32, F11 Float32) ENGINE MergeTree ORDER BY ID;" @@ -74,6 +78,8 @@ def testModelPathIsNotAConstString(ch_cluster): if instance.is_built_with_memory_sanitizer(): 
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries") + result = instance.query("system reload models") + err = instance.query_and_get_error( "select catboostEvaluate(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);" ) @@ -98,6 +104,8 @@ def testWrongNumberOfFeatureArguments(ch_cluster): if instance.is_built_with_memory_sanitizer(): pytest.skip("Memory Sanitizer cannot work with third-party shared libraries") + result = instance.query("system reload models") + err = instance.query_and_get_error( "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin');" ) @@ -116,6 +124,8 @@ def testFloatFeatureMustBeNumeric(ch_cluster): if instance.is_built_with_memory_sanitizer(): pytest.skip("Memory Sanitizer cannot work with third-party shared libraries") + result = instance.query("system reload models") + err = instance.query_and_get_error( "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 'a', 3, 4, 5, 6, 7, 8, 9, 10, 11);" ) @@ -126,6 +136,8 @@ def testCategoricalFeatureMustBeNumericOrString(ch_cluster): if instance.is_built_with_memory_sanitizer(): pytest.skip("Memory Sanitizer cannot work with third-party shared libraries") + result = instance.query("system reload models") + err = instance.query_and_get_error( "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, tuple(8), 9, 10, 11);" ) @@ -136,6 +148,8 @@ def testOnLowCardinalityFeatures(ch_cluster): if instance.is_built_with_memory_sanitizer(): pytest.skip("Memory Sanitizer cannot work with third-party shared libraries") + result = instance.query("system reload models") + # same but on domain-compressed data result = instance.query( "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', toLowCardinality(1.0), toLowCardinality(2.0), toLowCardinality(3), toLowCardinality(4), toLowCardinality(5), toLowCardinality(6), toLowCardinality(7), toLowCardinality(8), toLowCardinality(9), toLowCardinality(10), 
toLowCardinality(11));" @@ -148,6 +162,8 @@ def testOnNullableFeatures(ch_cluster): if instance.is_built_with_memory_sanitizer(): pytest.skip("Memory Sanitizer cannot work with third-party shared libraries") + result = instance.query("system reload models") + result = instance.query( "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', toNullable(1.0), toNullable(2.0), toNullable(3), toNullable(4), toNullable(5), toNullable(6), toNullable(7), toNullable(8), toNullable(9), toNullable(10), toNullable(11));" ) @@ -165,6 +181,8 @@ def testInvalidLibraryPath(ch_cluster): if instance.is_built_with_memory_sanitizer(): pytest.skip("Memory Sanitizer cannot work with third-party shared libraries") + result = instance.query("system reload models") + # temporarily move library elsewhere instance.exec_in_container( [ @@ -196,6 +214,8 @@ def testInvalidModelPath(ch_cluster): if instance.is_built_with_memory_sanitizer(): pytest.skip("Memory Sanitizer cannot work with third-party shared libraries") + result = instance.query("system reload models") + err = instance.query_and_get_error( "select catboostEvaluate('', 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);" ) @@ -211,6 +231,8 @@ def testRecoveryAfterCrash(ch_cluster): if instance.is_built_with_memory_sanitizer(): pytest.skip("Memory Sanitizer cannot work with third-party shared libraries") + result = instance.query("system reload models") + result = instance.query( "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);" ) @@ -235,6 +257,8 @@ def testAmazonModelSingleRow(ch_cluster): if instance.is_built_with_memory_sanitizer(): pytest.skip("Memory Sanitizer cannot work with third-party shared libraries") + result = instance.query("system reload models") + result = instance.query( "select catboostEvaluate('/etc/clickhouse-server/model/amazon_model.bin', 1, 2, 3, 4, 5, 6, 7, 8, 9);" ) @@ -246,6 +270,8 @@ def testAmazonModelManyRows(ch_cluster): if 
instance.is_built_with_memory_sanitizer(): pytest.skip("Memory Sanitizer cannot work with third-party shared libraries") + result = instance.query("system reload models") + result = instance.query("drop table if exists amazon") result = instance.query( @@ -272,6 +298,8 @@ def testModelUpdate(ch_cluster): if instance.is_built_with_memory_sanitizer(): pytest.skip("Memory Sanitizer cannot work with third-party shared libraries") + result = instance.query("system reload models") + query = "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);" result = instance.query(query) @@ -294,13 +322,18 @@ def testModelUpdate(ch_cluster): ] ) - # since the amazon model has a different number of features than the simple model, we should get an error - err = instance.query_and_get_error(query) - assert ( - "Number of columns is different with number of features: columns size 11 float features size 0 + cat features size 9" - in err + # unload simple model + result = instance.query( + "system reload model '/etc/clickhouse-server/model/simple_model.bin'" ) + # load the simple-model-camouflaged amazon model + result = instance.query( + "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1, 2, 3, 4, 5, 6, 7, 8, 9);" + ) + expected = "0.7774665009089274\n" + assert result == expected + # restore instance.exec_in_container( [ @@ -316,3 +349,54 @@ def testModelUpdate(ch_cluster): "mv /etc/clickhouse-server/model/simple_model.bin.bak /etc/clickhouse-server/model/simple_model.bin", ] ) + + +def testSystemModelsAndModelRefresh(ch_cluster): + if instance.is_built_with_memory_sanitizer(): + pytest.skip("Memory Sanitizer cannot work with third-party shared libraries") + + result = instance.query("system reload models") + + # check model system view + result = instance.query("select * from system.models") + expected = "" + assert result == expected + + # load simple model + result = instance.query( + "select 
catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);" + ) + expected = "-1.930268705869267\n" + assert result == expected + + # check model system view with one model loaded + result = instance.query("select * from system.models") + assert result.count("\n") == 1 + expected = "/etc/clickhouse-server/model/simple_model.bin" + assert expected in result + + # load amazon model + result = instance.query( + "select catboostEvaluate('/etc/clickhouse-server/model/amazon_model.bin', 1, 2, 3, 4, 5, 6, 7, 8, 9);" + ) + expected = "0.7774665009089274\n" + assert result == expected + + # check model system view with two models loaded + result = instance.query("select * from system.models") + assert result.count("\n") == 2 + expected = "/etc/clickhouse-server/model/simple_model.bin" + assert expected in result + expected = "/etc/clickhouse-server/model/amazon_model.bin" + assert expected in result + + # unload simple model + result = instance.query( + "system reload model '/etc/clickhouse-server/model/simple_model.bin'" + ) + + # check model system view, it should not display the removed model + result = instance.query("select * from system.models") + assert result.count("\n") == 1 + expected = "/etc/clickhouse-server/model/amazon_model.bin" + assert expected in result