mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-28 02:21:59 +00:00
chore: restore SYSTEM RELOAD MODEL(S) and moniting view SYSTEM.MODELS
- This commit restores statements "SYSTEM RELOAD MODEL(S)" which provide a mechanism to update a model explicitly. It also saves potentially unnecessary reloads of a model from disk after it's initial load. To keep the complexity low, the semantics of "SYSTEM RELOAD MODEL(S) was changed from eager to lazy. This means that both statements previously immedately reloaded the specified/all models, whereas now the statements only trigger an unload and the first call to catboostEvaluate() does the actual load. - Monitoring view SYSTEM.MODELS is also restored but with some obsolete fields removed. The view was not documented in the past and for now it remains undocumented. The commit is thus not considered a breach of ClickHouse's public interface.
This commit is contained in:
parent
c16707ff00
commit
fac1be9700
@ -11,6 +11,8 @@ The list of available `SYSTEM` statements:
|
||||
- [RELOAD EMBEDDED DICTIONARIES](#query_language-system-reload-emdedded-dictionaries)
|
||||
- [RELOAD DICTIONARIES](#query_language-system-reload-dictionaries)
|
||||
- [RELOAD DICTIONARY](#query_language-system-reload-dictionary)
|
||||
- [RELOAD MODELS](#query_language-system-reload-models)
|
||||
- [RELOAD MODEL](#query_language-system-reload-model)
|
||||
- [RELOAD FUNCTIONS](#query_language-system-reload-functions)
|
||||
- [RELOAD FUNCTION](#query_language-system-reload-functions)
|
||||
- [DROP DNS CACHE](#query_language-system-drop-dns-cache)
|
||||
@ -65,6 +67,26 @@ The status of the dictionary can be checked by querying the `system.dictionaries
|
||||
SELECT name, status FROM system.dictionaries;
|
||||
```
|
||||
|
||||
## RELOAD MODELS
|
||||
|
||||
Reloads all [CatBoost](../../guides/developer/apply-catboost-model.md) models.
|
||||
|
||||
**Syntax**
|
||||
|
||||
```sql
|
||||
SYSTEM RELOAD MODELS [ON CLUSTER cluster_name]
|
||||
```
|
||||
|
||||
## RELOAD MODEL
|
||||
|
||||
Completely reloads a CatBoost model `model_path`.
|
||||
|
||||
**Syntax**
|
||||
|
||||
```sql
|
||||
SYSTEM RELOAD MODEL [ON CLUSTER cluster_name] <model_path>
|
||||
```
|
||||
|
||||
## RELOAD FUNCTIONS
|
||||
|
||||
Reloads all registered [executable user defined functions](../functions/index.md#executable-user-defined-functions) or one of them from a configuration file.
|
||||
|
@ -9,6 +9,8 @@ sidebar_label: SYSTEM
|
||||
- [RELOAD EMBEDDED DICTIONARIES](#query_language-system-reload-emdedded-dictionaries)
|
||||
- [RELOAD DICTIONARIES](#query_language-system-reload-dictionaries)
|
||||
- [RELOAD DICTIONARY](#query_language-system-reload-dictionary)
|
||||
- [RELOAD MODELS](#query_language-system-reload-models)
|
||||
- [RELOAD MODEL](#query_language-system-reload-model)
|
||||
- [RELOAD FUNCTIONS](#query_language-system-reload-functions)
|
||||
- [RELOAD FUNCTION](#query_language-system-reload-functions)
|
||||
- [DROP DNS CACHE](#query_language-system-drop-dns-cache)
|
||||
@ -62,6 +64,26 @@ sidebar_label: SYSTEM
|
||||
SELECT name, status FROM system.dictionaries;
|
||||
```
|
||||
|
||||
## RELOAD MODELS {#query_language-system-reload-models}
|
||||
|
||||
Перегружает все модели [CatBoost](../../guides/apply-catboost-model.md#applying-catboost-model-in-clickhouse).
|
||||
|
||||
**Синтаксис**
|
||||
|
||||
```sql
|
||||
SYSTEM RELOAD MODELS
|
||||
```
|
||||
|
||||
## RELOAD MODEL {#query_language-system-reload-model}
|
||||
|
||||
Полностью перегружает модель [CatBoost](../../guides/apply-catboost-model.md#applying-catboost-model-in-clickhouse) `model_path`.
|
||||
|
||||
**Синтаксис**
|
||||
|
||||
```sql
|
||||
SYSTEM RELOAD MODEL <model_path>
|
||||
```
|
||||
|
||||
## RELOAD FUNCTIONS {#query_language-system-reload-functions}
|
||||
|
||||
Перезагружает все зарегистрированные [исполняемые пользовательские функции](../functions/index.md#executable-user-defined-functions) или одну из них из файла конфигурации.
|
||||
|
@ -34,7 +34,8 @@ CatBoostLibraryHandler::APIHolder::APIHolder(SharedLibrary & lib)
|
||||
CatBoostLibraryHandler::CatBoostLibraryHandler(
|
||||
const std::string & library_path,
|
||||
const std::string & model_path)
|
||||
: library(std::make_shared<SharedLibrary>(library_path))
|
||||
: loading_start_time(std::chrono::system_clock::now())
|
||||
, library(std::make_shared<SharedLibrary>(library_path))
|
||||
, api(*library)
|
||||
{
|
||||
model_calcer_handle = api.ModelCalcerCreate();
|
||||
@ -51,6 +52,8 @@ CatBoostLibraryHandler::CatBoostLibraryHandler(
|
||||
tree_count = 1;
|
||||
if (api.GetDimensionsCount)
|
||||
tree_count = api.GetDimensionsCount(model_calcer_handle);
|
||||
|
||||
loading_duration = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now() - loading_start_time);
|
||||
}
|
||||
|
||||
CatBoostLibraryHandler::~CatBoostLibraryHandler()
|
||||
@ -58,6 +61,16 @@ CatBoostLibraryHandler::~CatBoostLibraryHandler()
|
||||
api.ModelCalcerDelete(model_calcer_handle);
|
||||
}
|
||||
|
||||
std::chrono::system_clock::time_point CatBoostLibraryHandler::getLoadingStartTime() const
|
||||
{
|
||||
return loading_start_time;
|
||||
}
|
||||
|
||||
std::chrono::milliseconds CatBoostLibraryHandler::getLoadingDuration() const
|
||||
{
|
||||
return loading_duration;
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
|
@ -10,6 +10,7 @@
|
||||
#include <Common/SharedLibrary.h>
|
||||
#include <base/defines.h>
|
||||
|
||||
#include <chrono>
|
||||
#include <mutex>
|
||||
|
||||
namespace DB
|
||||
@ -42,16 +43,22 @@ public:
|
||||
};
|
||||
|
||||
CatBoostLibraryHandler(
|
||||
const std::string & library_path,
|
||||
const std::string & model_path);
|
||||
const String & library_path,
|
||||
const String & model_path);
|
||||
|
||||
~CatBoostLibraryHandler();
|
||||
|
||||
std::chrono::system_clock::time_point getLoadingStartTime() const;
|
||||
std::chrono::milliseconds getLoadingDuration() const;
|
||||
|
||||
size_t getTreeCount() const;
|
||||
|
||||
ColumnPtr evaluate(const ColumnRawPtrs & columns) const;
|
||||
|
||||
private:
|
||||
std::chrono::system_clock::time_point loading_start_time;
|
||||
std::chrono::milliseconds loading_duration;
|
||||
|
||||
const SharedLibraryPtr library;
|
||||
const APIHolder api;
|
||||
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#include <Common/logger_useful.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
@ -11,39 +12,64 @@ CatBoostLibraryHandlerFactory & CatBoostLibraryHandlerFactory::instance()
|
||||
return instance;
|
||||
}
|
||||
|
||||
CatBoostLibraryHandlerPtr CatBoostLibraryHandlerFactory::get(const String & model_path)
|
||||
CatBoostLibraryHandlerPtr CatBoostLibraryHandlerFactory::getOrCreateModel(const String & model_path, const String & library_path, bool create_if_not_found)
|
||||
{
|
||||
std::lock_guard lock(mutex);
|
||||
|
||||
if (auto handler = library_handlers.find(model_path); handler != library_handlers.end())
|
||||
auto handler = library_handlers.find(model_path);
|
||||
bool found = (handler != library_handlers.end());
|
||||
|
||||
if (found)
|
||||
return handler->second;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void CatBoostLibraryHandlerFactory::create(const String & library_path, const String & model_path)
|
||||
{
|
||||
std::lock_guard lock(mutex);
|
||||
|
||||
if (library_handlers.contains(model_path))
|
||||
else
|
||||
{
|
||||
LOG_DEBUG(&Poco::Logger::get("CatBoostLibraryHandlerFactory"), "Cannot load catboost library handler for model path {} because it exists already", model_path);
|
||||
return;
|
||||
if (create_if_not_found)
|
||||
{
|
||||
auto new_handler = std::make_shared<CatBoostLibraryHandler>(library_path, model_path);
|
||||
library_handlers.emplace(std::make_pair(model_path, new_handler));
|
||||
LOG_DEBUG(&Poco::Logger::get("CatBoostLibraryHandlerFactory"), "Loaded catboost library handler for model path '{}'", model_path);
|
||||
return new_handler;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
library_handlers.emplace(std::make_pair(model_path, std::make_shared<CatBoostLibraryHandler>(library_path, model_path)));
|
||||
LOG_DEBUG(&Poco::Logger::get("CatBoostLibraryHandlerFactory"), "Loaded catboost library handler for model path {}.", model_path);
|
||||
}
|
||||
|
||||
void CatBoostLibraryHandlerFactory::remove(const String & model_path)
|
||||
void CatBoostLibraryHandlerFactory::removeModel(const String & model_path)
|
||||
{
|
||||
std::lock_guard lock(mutex);
|
||||
|
||||
bool deleted = library_handlers.erase(model_path);
|
||||
if (!deleted)
|
||||
{
|
||||
LOG_DEBUG(&Poco::Logger::get("CatBoostLibraryHandlerFactory"), "Cannot unload catboost library handler for model path: {}", model_path);
|
||||
LOG_DEBUG(&Poco::Logger::get("CatBoostLibraryHandlerFactory"), "Cannot unload catboost library handler for model path '{}'", model_path);
|
||||
return;
|
||||
}
|
||||
LOG_DEBUG(&Poco::Logger::get("CatBoostLibraryHandlerFactory"), "Unloaded catboost library handler for model path: {}", model_path);
|
||||
LOG_DEBUG(&Poco::Logger::get("CatBoostLibraryHandlerFactory"), "Unloaded catboost library handler for model path '{}'", model_path);
|
||||
}
|
||||
|
||||
void CatBoostLibraryHandlerFactory::removeAllModels()
|
||||
{
|
||||
std::lock_guard lock(mutex);
|
||||
library_handlers.clear();
|
||||
LOG_DEBUG(&Poco::Logger::get("CatBoostLibraryHandlerFactory"), "Unloaded all catboost library handlers");
|
||||
}
|
||||
|
||||
ExternalModelInfos CatBoostLibraryHandlerFactory::getModelInfos()
|
||||
{
|
||||
std::lock_guard lock(mutex);
|
||||
|
||||
ExternalModelInfos result;
|
||||
|
||||
for (const auto & handler : library_handlers)
|
||||
result.push_back({
|
||||
.model_path = handler.first,
|
||||
.model_type = "catboost",
|
||||
.loading_start_time = handler.second->getLoadingStartTime(),
|
||||
.loading_duration = handler.second->getLoadingDuration()
|
||||
});
|
||||
|
||||
return result;
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -3,7 +3,9 @@
|
||||
#include "CatBoostLibraryHandler.h"
|
||||
|
||||
#include <base/defines.h>
|
||||
#include <Common/ExternalModelInfo.h>
|
||||
|
||||
#include <chrono>
|
||||
#include <mutex>
|
||||
#include <unordered_map>
|
||||
|
||||
@ -16,14 +18,15 @@ class CatBoostLibraryHandlerFactory final : private boost::noncopyable
|
||||
public:
|
||||
static CatBoostLibraryHandlerFactory & instance();
|
||||
|
||||
CatBoostLibraryHandlerPtr get(const String & model_path);
|
||||
CatBoostLibraryHandlerPtr getOrCreateModel(const String & model_path, const String & library_path, bool create_if_not_found);
|
||||
|
||||
void create(const String & library_path, const String & model_path);
|
||||
void removeModel(const String & model_path);
|
||||
void removeAllModels();
|
||||
|
||||
void remove(const String & model_path);
|
||||
ExternalModelInfos getModelInfos();
|
||||
|
||||
private:
|
||||
/// map: model path -> shared library handler
|
||||
/// map: model path --> catboost library handler
|
||||
std::unordered_map<String, CatBoostLibraryHandlerPtr> library_handlers TSA_GUARDED_BY(mutex);
|
||||
std::mutex mutex;
|
||||
};
|
||||
|
@ -26,6 +26,7 @@
|
||||
#include <Formats/NativeReader.h>
|
||||
#include <Formats/NativeWriter.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypeString.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
@ -422,8 +423,6 @@ CatBoostLibraryBridgeRequestHandler::CatBoostLibraryBridgeRequestHandler(
|
||||
|
||||
void CatBoostLibraryBridgeRequestHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response)
|
||||
{
|
||||
std::lock_guard lock(mutex);
|
||||
|
||||
LOG_TRACE(log, "Request URI: {}", request.getURI());
|
||||
HTMLForm params(getContext()->getSettingsRef(), request);
|
||||
|
||||
@ -460,7 +459,51 @@ void CatBoostLibraryBridgeRequestHandler::handleRequest(HTTPServerRequest & requ
|
||||
|
||||
try
|
||||
{
|
||||
if (method == "catboost_GetTreeCount")
|
||||
if (method == "catboost_list")
|
||||
{
|
||||
ExternalModelInfos model_infos = CatBoostLibraryHandlerFactory::instance().getModelInfos();
|
||||
|
||||
writeIntBinary(static_cast<UInt64>(model_infos.size()), out);
|
||||
|
||||
for (const auto & info : model_infos)
|
||||
{
|
||||
writeStringBinary(info.model_path, out);
|
||||
writeStringBinary(info.model_type, out);
|
||||
|
||||
UInt64 t = std::chrono::system_clock::to_time_t(info.loading_start_time);
|
||||
writeIntBinary(t, out);
|
||||
|
||||
t = info.loading_duration.count();
|
||||
writeIntBinary(t, out);
|
||||
|
||||
}
|
||||
}
|
||||
else if (method == "catboost_removeModel")
|
||||
{
|
||||
auto & read_buf = request.getStream();
|
||||
params.read(read_buf);
|
||||
|
||||
if (!params.has("model_path"))
|
||||
{
|
||||
processError(response, "No 'model_path' in request URL");
|
||||
return;
|
||||
}
|
||||
|
||||
const String & model_path = params.get("model_path");
|
||||
|
||||
CatBoostLibraryHandlerFactory::instance().removeModel(model_path);
|
||||
|
||||
String res = "1";
|
||||
writeStringBinary(res, out);
|
||||
}
|
||||
else if (method == "catboost_removeAllModels")
|
||||
{
|
||||
CatBoostLibraryHandlerFactory::instance().removeAllModels();
|
||||
|
||||
String res = "1";
|
||||
writeStringBinary(res, out);
|
||||
}
|
||||
else if (method == "catboost_GetTreeCount")
|
||||
{
|
||||
auto & read_buf = request.getStream();
|
||||
params.read(read_buf);
|
||||
@ -481,18 +524,7 @@ void CatBoostLibraryBridgeRequestHandler::handleRequest(HTTPServerRequest & requ
|
||||
|
||||
const String & model_path = params.get("model_path");
|
||||
|
||||
CatBoostLibraryHandlerFactory::instance().remove(model_path);
|
||||
|
||||
CatBoostLibraryHandlerFactory::instance().create(library_path, model_path);
|
||||
|
||||
auto catboost_handler = CatBoostLibraryHandlerFactory::instance().get(model_path);
|
||||
|
||||
if (!catboost_handler)
|
||||
{
|
||||
processError(response, "CatBoost library is not loaded for model " + model_path);
|
||||
return;
|
||||
}
|
||||
|
||||
auto catboost_handler = CatBoostLibraryHandlerFactory::instance().getOrCreateModel(model_path, library_path, /*create_if_not_found*/ true);
|
||||
size_t tree_count = catboost_handler->getTreeCount();
|
||||
writeIntBinary(tree_count, out);
|
||||
}
|
||||
@ -526,11 +558,11 @@ void CatBoostLibraryBridgeRequestHandler::handleRequest(HTTPServerRequest & requ
|
||||
for (const auto & p : col_ptrs)
|
||||
col_raw_ptrs.push_back(&*p);
|
||||
|
||||
auto catboost_handler = CatBoostLibraryHandlerFactory::instance().get(model_path);
|
||||
auto catboost_handler = CatBoostLibraryHandlerFactory::instance().getOrCreateModel(model_path, "DummyLibraryPath", /*create_if_not_found*/ false);
|
||||
|
||||
if (!catboost_handler)
|
||||
{
|
||||
processError(response, "CatBoost library is not loaded for model" + model_path);
|
||||
processError(response, "CatBoost library is not loaded for model '" + model_path + "'. Please try again.");
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -3,7 +3,6 @@
|
||||
#include <Common/logger_useful.h>
|
||||
#include <Interpreters/Context.h>
|
||||
#include <Server/HTTP/HTTPRequestHandler.h>
|
||||
#include <mutex>
|
||||
|
||||
|
||||
namespace DB
|
||||
@ -46,16 +45,21 @@ private:
|
||||
|
||||
|
||||
/// Handler for requests to catboost library. The call protocol is as follows:
|
||||
/// (1) Send a "catboost_GetTreeCount" request from the server to the bridge, containing a library path (e.g /home/user/libcatboost.so) and
|
||||
/// a model path (e.g. /home/user/model.bin). Rirst, this unloads the catboost library handler associated to the model path (if it was
|
||||
/// loaded), then loads the catboost library handler associated to the model path, then executes GetTreeCount() on the library handler
|
||||
/// and finally sends the result back to the server.
|
||||
/// Step (1) is called once by the server from FunctionCatBoostEvaluate::getReturnTypeImpl(). The library path handler is unloaded in
|
||||
/// the beginning because it contains state which may no longer be valid if the user runs catboost("/path/to/model.bin", ...) more than
|
||||
/// once and if "model.bin" was updated in between.
|
||||
/// (2) Send "catboost_Evaluate" from the server to the bridge, containing the model path and the features to run the interference on.
|
||||
/// Step (2) is called multiple times (once per chunk) by the server from function FunctionCatBoostEvaluate::executeImpl(). The library
|
||||
/// handler for the given model path is expected to be already loaded by Step (1).
|
||||
/// (1) Send a "catboost_GetTreeCount" request from the server to the bridge. It contains a library path (e.g /home/user/libcatboost.so) and
|
||||
/// a model path (e.g. /home/user/model.bin). This loads the catboost library handler associated with the model path, then executes
|
||||
/// GetTreeCount() on the library handler and sends the result back to the server.
|
||||
/// (2) Send "catboost_Evaluate" from the server to the bridge. It contains a model path and the features to run the interference on. Step
|
||||
/// (2) is called multiple times (once per chunk) by the server.
|
||||
///
|
||||
/// We would ideally like to have steps (1) and (2) in one atomic handler but can't because the evaluation on the server side is divided
|
||||
/// into two dependent phases: FunctionCatBoostEvaluate::getReturnTypeImpl() and ::executeImpl(). So the model may in principle be unloaded
|
||||
/// from the library-bridge between steps (1) and (2). Step (2) checks if that is the case and fails gracefully. This is okay because that
|
||||
/// situation considered exceptional and rare.
|
||||
///
|
||||
/// An update of a model is performed by unloading it. The first call to "catboost_GetTreeCount" brings it into memory again.
|
||||
///
|
||||
/// Further handlers are provided for unloading a specific model, for unloading all models or for retrieving information about the loaded
|
||||
/// models for display in a system view.
|
||||
class CatBoostLibraryBridgeRequestHandler : public HTTPRequestHandler, WithContext
|
||||
{
|
||||
public:
|
||||
@ -64,7 +68,6 @@ public:
|
||||
void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) override;
|
||||
|
||||
private:
|
||||
std::mutex mutex;
|
||||
const size_t keep_alive_timeout;
|
||||
Poco::Logger * log;
|
||||
};
|
||||
|
@ -10,6 +10,8 @@
|
||||
#include <IO/WriteBufferFromString.h>
|
||||
#include <Poco/Net/HTTPRequest.h>
|
||||
|
||||
#include <random>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
@ -28,6 +30,16 @@ CatBoostLibraryBridgeHelper::CatBoostLibraryBridgeHelper(
|
||||
{
|
||||
}
|
||||
|
||||
CatBoostLibraryBridgeHelper::CatBoostLibraryBridgeHelper(ContextPtr context_, std::string_view model_path_)
|
||||
: CatBoostLibraryBridgeHelper(context_, "", model_path_)
|
||||
{
|
||||
}
|
||||
|
||||
CatBoostLibraryBridgeHelper::CatBoostLibraryBridgeHelper(ContextPtr context_)
|
||||
: CatBoostLibraryBridgeHelper(context_, "", "")
|
||||
{
|
||||
}
|
||||
|
||||
Poco::URI CatBoostLibraryBridgeHelper::getPingURI() const
|
||||
{
|
||||
auto uri = createBaseURI();
|
||||
@ -71,6 +83,74 @@ bool CatBoostLibraryBridgeHelper::bridgeHandShake()
|
||||
return true;
|
||||
}
|
||||
|
||||
ExternalModelInfos CatBoostLibraryBridgeHelper::listModels()
|
||||
{
|
||||
startBridgeSync();
|
||||
|
||||
ReadWriteBufferFromHTTP buf(
|
||||
createRequestURI(CATBOOST_LIST_METHOD),
|
||||
Poco::Net::HTTPRequest::HTTP_POST,
|
||||
[](std::ostream &) {},
|
||||
http_timeouts, credentials);
|
||||
|
||||
ExternalModelInfos result;
|
||||
|
||||
UInt64 num_rows;
|
||||
readIntBinary(num_rows, buf);
|
||||
|
||||
for (UInt64 i = 0; i < num_rows; ++i)
|
||||
{
|
||||
ExternalModelInfo info;
|
||||
|
||||
readStringBinary(info.model_path, buf);
|
||||
readStringBinary(info.model_type, buf);
|
||||
|
||||
UInt64 t;
|
||||
readIntBinary(t, buf);
|
||||
info.loading_start_time = std::chrono::system_clock::from_time_t(t);
|
||||
|
||||
readIntBinary(t, buf);
|
||||
info.loading_duration = std::chrono::milliseconds(t);
|
||||
|
||||
result.push_back(info);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
void CatBoostLibraryBridgeHelper::removeModel()
|
||||
{
|
||||
startBridgeSync();
|
||||
|
||||
ReadWriteBufferFromHTTP buf(
|
||||
createRequestURI(CATBOOST_REMOVEMODEL_METHOD),
|
||||
Poco::Net::HTTPRequest::HTTP_POST,
|
||||
[this](std::ostream & os)
|
||||
{
|
||||
os << "model_path=" << escapeForFileName(model_path);
|
||||
},
|
||||
http_timeouts, credentials);
|
||||
|
||||
String result;
|
||||
readStringBinary(result, buf);
|
||||
assert(result == "1");
|
||||
}
|
||||
|
||||
void CatBoostLibraryBridgeHelper::removeAllModels()
|
||||
{
|
||||
startBridgeSync();
|
||||
|
||||
ReadWriteBufferFromHTTP buf(
|
||||
createRequestURI(CATBOOST_REMOVEALLMODELS_METHOD),
|
||||
Poco::Net::HTTPRequest::HTTP_POST,
|
||||
[](std::ostream &){},
|
||||
http_timeouts, credentials);
|
||||
|
||||
String result;
|
||||
readStringBinary(result, buf);
|
||||
assert(result == "1");
|
||||
}
|
||||
|
||||
size_t CatBoostLibraryBridgeHelper::getTreeCount()
|
||||
{
|
||||
startBridgeSync();
|
||||
@ -85,9 +165,9 @@ size_t CatBoostLibraryBridgeHelper::getTreeCount()
|
||||
},
|
||||
http_timeouts, credentials);
|
||||
|
||||
size_t res;
|
||||
readIntBinary(res, buf);
|
||||
return res;
|
||||
size_t result;
|
||||
readIntBinary(result, buf);
|
||||
return result;
|
||||
}
|
||||
|
||||
ColumnPtr CatBoostLibraryBridgeHelper::evaluate(const ColumnsWithTypeAndName & columns)
|
||||
|
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <BridgeHelper/LibraryBridgeHelper.h>
|
||||
#include <Common/ExternalModelInfo.h>
|
||||
#include <DataTypes/IDataType.h>
|
||||
#include <IO/ReadWriteBufferFromHTTP.h>
|
||||
#include <Interpreters/Context.h>
|
||||
@ -17,9 +18,15 @@ public:
|
||||
static constexpr inline auto MAIN_HANDLER = "/catboost_request";
|
||||
|
||||
CatBoostLibraryBridgeHelper(ContextPtr context_, std::string_view library_path_, std::string_view model_path_);
|
||||
CatBoostLibraryBridgeHelper(ContextPtr context_, std::string_view model_path_);
|
||||
explicit CatBoostLibraryBridgeHelper(ContextPtr context_);
|
||||
|
||||
ExternalModelInfos listModels();
|
||||
|
||||
void removeModel();
|
||||
void removeAllModels();
|
||||
|
||||
size_t getTreeCount();
|
||||
|
||||
ColumnPtr evaluate(const ColumnsWithTypeAndName & columns);
|
||||
|
||||
protected:
|
||||
@ -30,6 +37,9 @@ protected:
|
||||
bool bridgeHandShake() override;
|
||||
|
||||
private:
|
||||
static constexpr inline auto CATBOOST_LIST_METHOD = "catboost_list";
|
||||
static constexpr inline auto CATBOOST_REMOVEMODEL_METHOD = "catboost_removeModel";
|
||||
static constexpr inline auto CATBOOST_REMOVEALLMODELS_METHOD = "catboost_removeAllModels";
|
||||
static constexpr inline auto CATBOOST_GETTREECOUNT_METHOD = "catboost_GetTreeCount";
|
||||
static constexpr inline auto CATBOOST_LIB_EVALUATE_METHOD = "catboost_libEvaluate";
|
||||
|
||||
|
20
src/Common/ExternalModelInfo.h
Normal file
20
src/Common/ExternalModelInfo.h
Normal file
@ -0,0 +1,20 @@
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <base/types.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
/// Details about external machine learning model, used by clickhouse-server and clickhouse-library-bridge
|
||||
struct ExternalModelInfo
|
||||
{
|
||||
String model_path;
|
||||
String model_type;
|
||||
std::chrono::system_clock::time_point loading_start_time; /// serialized as std::time_t
|
||||
std::chrono::milliseconds loading_duration; /// serialized as UInt64
|
||||
};
|
||||
|
||||
using ExternalModelInfos = std::vector<ExternalModelInfo>;
|
||||
|
||||
}
|
@ -35,6 +35,7 @@
|
||||
#include <Interpreters/ProcessorsProfileLog.h>
|
||||
#include <Interpreters/JIT/CompiledExpressionCache.h>
|
||||
#include <Interpreters/TransactionLog.h>
|
||||
#include <BridgeHelper/CatBoostLibraryBridgeHelper.h>
|
||||
#include <Access/ContextAccess.h>
|
||||
#include <Access/Common/AllowedClientHosts.h>
|
||||
#include <Databases/IDatabase.h>
|
||||
@ -373,10 +374,17 @@ BlockIO InterpreterSystemQuery::execute()
|
||||
break;
|
||||
}
|
||||
case Type::RELOAD_MODEL:
|
||||
{
|
||||
getContext()->checkAccess(AccessType::SYSTEM_RELOAD_MODEL);
|
||||
auto bridge_helper = std::make_unique<CatBoostLibraryBridgeHelper>(getContext(), query.target_model);
|
||||
bridge_helper->removeModel();
|
||||
break;
|
||||
}
|
||||
case Type::RELOAD_MODELS:
|
||||
{
|
||||
getContext()->checkAccess(AccessType::SYSTEM_RELOAD_MODEL);
|
||||
/// SYSTEM RELOAD MODEL(S) is no longer supported. For compat reasons, it is not completely removed but retained as no-op.
|
||||
auto bridge_helper = std::make_unique<CatBoostLibraryBridgeHelper>(getContext());
|
||||
bridge_helper->removeAllModels();
|
||||
break;
|
||||
}
|
||||
case Type::RELOAD_FUNCTION:
|
||||
|
38
src/Storages/System/StorageSystemModels.cpp
Normal file
38
src/Storages/System/StorageSystemModels.cpp
Normal file
@ -0,0 +1,38 @@
|
||||
#include <Storages/System/StorageSystemModels.h>
|
||||
#include <Common/ExternalModelInfo.h>
|
||||
#include <DataTypes/DataTypeString.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypeDateTime.h>
|
||||
#include <DataTypes/DataTypeEnum.h>
|
||||
#include <Interpreters/Context.h>
|
||||
#include <BridgeHelper/CatBoostLibraryBridgeHelper.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
NamesAndTypesList StorageSystemModels::getNamesAndTypes()
|
||||
{
|
||||
return {
|
||||
{ "model_path", std::make_shared<DataTypeString>() },
|
||||
{ "type", std::make_shared<DataTypeString>() },
|
||||
{ "loading_start_time", std::make_shared<DataTypeDateTime>() },
|
||||
{ "loading_duration", std::make_shared<DataTypeFloat32>() },
|
||||
};
|
||||
}
|
||||
|
||||
void StorageSystemModels::fillData(MutableColumns & res_columns, ContextPtr context, const SelectQueryInfo &) const
|
||||
{
|
||||
auto bridge_helper = std::make_unique<CatBoostLibraryBridgeHelper>(context);
|
||||
ExternalModelInfos infos = bridge_helper->listModels();
|
||||
|
||||
for (const auto & info : infos)
|
||||
{
|
||||
res_columns[0]->insert(info.model_path);
|
||||
res_columns[1]->insert(info.model_type);
|
||||
res_columns[2]->insert(static_cast<UInt64>(std::chrono::system_clock::to_time_t(info.loading_start_time)));
|
||||
res_columns[3]->insert(std::chrono::duration_cast<std::chrono::duration<float>>(info.loading_duration).count());
|
||||
}
|
||||
}
|
||||
|
||||
}
|
25
src/Storages/System/StorageSystemModels.h
Normal file
25
src/Storages/System/StorageSystemModels.h
Normal file
@ -0,0 +1,25 @@
|
||||
#pragma once
|
||||
|
||||
#include <Storages/System/IStorageSystemOneBlock.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
class Context;
|
||||
|
||||
|
||||
class StorageSystemModels final : public IStorageSystemOneBlock<StorageSystemModels>
|
||||
{
|
||||
public:
|
||||
std::string getName() const override { return "SystemModels"; }
|
||||
|
||||
static NamesAndTypesList getNamesAndTypes();
|
||||
|
||||
protected:
|
||||
using IStorageSystemOneBlock::IStorageSystemOneBlock;
|
||||
|
||||
void fillData(MutableColumns & res_columns, ContextPtr context, const SelectQueryInfo & query_info) const override;
|
||||
};
|
||||
|
||||
}
|
@ -25,6 +25,7 @@
|
||||
#include <Storages/System/StorageSystemMerges.h>
|
||||
#include <Storages/System/StorageSystemReplicatedFetches.h>
|
||||
#include <Storages/System/StorageSystemMetrics.h>
|
||||
#include <Storages/System/StorageSystemModels.h>
|
||||
#include <Storages/System/StorageSystemMutations.h>
|
||||
#include <Storages/System/StorageSystemNumbers.h>
|
||||
#include <Storages/System/StorageSystemOne.h>
|
||||
@ -163,6 +164,7 @@ void attachSystemTablesServer(ContextPtr context, IDatabase & system_database, b
|
||||
attach<StorageSystemDDLWorkerQueue>(context, system_database, "distributed_ddl_queue");
|
||||
attach<StorageSystemDistributionQueue>(context, system_database, "distribution_queue");
|
||||
attach<StorageSystemDictionaries>(context, system_database, "dictionaries");
|
||||
attach<StorageSystemModels>(context, system_database, "models");
|
||||
attach<StorageSystemClusters>(context, system_database, "clusters");
|
||||
attach<StorageSystemGraphite>(context, system_database, "graphite_retentions");
|
||||
attach<StorageSystemMacros>(context, system_database, "macros");
|
||||
|
@ -44,6 +44,8 @@ def testConstantFeatures(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
result = instance.query("system reload models")
|
||||
|
||||
result = instance.query(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
|
||||
)
|
||||
@ -55,6 +57,8 @@ def testNonConstantFeatures(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
result = instance.query("system reload models")
|
||||
|
||||
instance.query("DROP TABLE IF EXISTS T;")
|
||||
instance.query(
|
||||
"CREATE TABLE T(ID UInt32, F1 Float32, F2 Float32, F3 UInt32, F4 UInt32, F5 UInt32, F6 UInt32, F7 UInt32, F8 UInt32, F9 Float32, F10 Float32, F11 Float32) ENGINE MergeTree ORDER BY ID;"
|
||||
@ -74,6 +78,8 @@ def testModelPathIsNotAConstString(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
result = instance.query("system reload models")
|
||||
|
||||
err = instance.query_and_get_error(
|
||||
"select catboostEvaluate(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
|
||||
)
|
||||
@ -98,6 +104,8 @@ def testWrongNumberOfFeatureArguments(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
result = instance.query("system reload models")
|
||||
|
||||
err = instance.query_and_get_error(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin');"
|
||||
)
|
||||
@ -116,6 +124,8 @@ def testFloatFeatureMustBeNumeric(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
result = instance.query("system reload models")
|
||||
|
||||
err = instance.query_and_get_error(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 'a', 3, 4, 5, 6, 7, 8, 9, 10, 11);"
|
||||
)
|
||||
@ -126,6 +136,8 @@ def testCategoricalFeatureMustBeNumericOrString(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
result = instance.query("system reload models")
|
||||
|
||||
err = instance.query_and_get_error(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, tuple(8), 9, 10, 11);"
|
||||
)
|
||||
@ -136,6 +148,8 @@ def testOnLowCardinalityFeatures(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
result = instance.query("system reload models")
|
||||
|
||||
# same but on domain-compressed data
|
||||
result = instance.query(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', toLowCardinality(1.0), toLowCardinality(2.0), toLowCardinality(3), toLowCardinality(4), toLowCardinality(5), toLowCardinality(6), toLowCardinality(7), toLowCardinality(8), toLowCardinality(9), toLowCardinality(10), toLowCardinality(11));"
|
||||
@ -148,6 +162,8 @@ def testOnNullableFeatures(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
result = instance.query("system reload models")
|
||||
|
||||
result = instance.query(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', toNullable(1.0), toNullable(2.0), toNullable(3), toNullable(4), toNullable(5), toNullable(6), toNullable(7), toNullable(8), toNullable(9), toNullable(10), toNullable(11));"
|
||||
)
|
||||
@ -165,6 +181,8 @@ def testInvalidLibraryPath(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
result = instance.query("system reload models")
|
||||
|
||||
# temporarily move library elsewhere
|
||||
instance.exec_in_container(
|
||||
[
|
||||
@ -196,6 +214,8 @@ def testInvalidModelPath(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
result = instance.query("system reload models")
|
||||
|
||||
err = instance.query_and_get_error(
|
||||
"select catboostEvaluate('', 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
|
||||
)
|
||||
@ -211,6 +231,8 @@ def testRecoveryAfterCrash(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
result = instance.query("system reload models")
|
||||
|
||||
result = instance.query(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
|
||||
)
|
||||
@ -235,6 +257,8 @@ def testAmazonModelSingleRow(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
result = instance.query("system reload models")
|
||||
|
||||
result = instance.query(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/amazon_model.bin', 1, 2, 3, 4, 5, 6, 7, 8, 9);"
|
||||
)
|
||||
@ -246,6 +270,8 @@ def testAmazonModelManyRows(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
result = instance.query("system reload models")
|
||||
|
||||
result = instance.query("drop table if exists amazon")
|
||||
|
||||
result = instance.query(
|
||||
@ -272,6 +298,8 @@ def testModelUpdate(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
result = instance.query("system reload models")
|
||||
|
||||
query = "select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
|
||||
|
||||
result = instance.query(query)
|
||||
@ -294,13 +322,18 @@ def testModelUpdate(ch_cluster):
|
||||
]
|
||||
)
|
||||
|
||||
# since the amazon model has a different number of features than the simple model, we should get an error
|
||||
err = instance.query_and_get_error(query)
|
||||
assert (
|
||||
"Number of columns is different with number of features: columns size 11 float features size 0 + cat features size 9"
|
||||
in err
|
||||
# unload simple model
|
||||
result = instance.query(
|
||||
"system reload model '/etc/clickhouse-server/model/simple_model.bin'"
|
||||
)
|
||||
|
||||
# load the simple-model-camouflaged amazon model
|
||||
result = instance.query(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1, 2, 3, 4, 5, 6, 7, 8, 9);"
|
||||
)
|
||||
expected = "0.7774665009089274\n"
|
||||
assert result == expected
|
||||
|
||||
# restore
|
||||
instance.exec_in_container(
|
||||
[
|
||||
@ -316,3 +349,54 @@ def testModelUpdate(ch_cluster):
|
||||
"mv /etc/clickhouse-server/model/simple_model.bin.bak /etc/clickhouse-server/model/simple_model.bin",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def testSystemModelsAndModelRefresh(ch_cluster):
|
||||
if instance.is_built_with_memory_sanitizer():
|
||||
pytest.skip("Memory Sanitizer cannot work with third-party shared libraries")
|
||||
|
||||
result = instance.query("system reload models")
|
||||
|
||||
# check model system view
|
||||
result = instance.query("select * from system.models")
|
||||
expected = ""
|
||||
assert result == expected
|
||||
|
||||
# load simple model
|
||||
result = instance.query(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/simple_model.bin', 1.0, 2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11);"
|
||||
)
|
||||
expected = "-1.930268705869267\n"
|
||||
assert result == expected
|
||||
|
||||
# check model system view with one model loaded
|
||||
result = instance.query("select * from system.models")
|
||||
assert result.count("\n") == 1
|
||||
expected = "/etc/clickhouse-server/model/simple_model.bin"
|
||||
assert expected in result
|
||||
|
||||
# load amazon model
|
||||
result = instance.query(
|
||||
"select catboostEvaluate('/etc/clickhouse-server/model/amazon_model.bin', 1, 2, 3, 4, 5, 6, 7, 8, 9);"
|
||||
)
|
||||
expected = "0.7774665009089274\n"
|
||||
assert result == expected
|
||||
|
||||
# check model system view with one model loaded
|
||||
result = instance.query("select * from system.models")
|
||||
assert result.count("\n") == 2
|
||||
expected = "/etc/clickhouse-server/model/simple_model.bin"
|
||||
assert expected in result
|
||||
expected = "/etc/clickhouse-server/model/amazon_model.bin"
|
||||
assert expected in result
|
||||
|
||||
# unload simple model
|
||||
result = instance.query(
|
||||
"system reload model '/etc/clickhouse-server/model/simple_model.bin'"
|
||||
)
|
||||
|
||||
# check model system view, it should not display the removed model
|
||||
result = instance.query("select * from system.models")
|
||||
assert result.count("\n") == 1
|
||||
expected = "/etc/clickhouse-server/model/amazon_model.bin"
|
||||
assert expected in result
|
||||
|
Loading…
Reference in New Issue
Block a user