diff --git a/src/Access/AccessType.h b/src/Access/AccessType.h index 952cddba5f5..d5185b9931d 100644 --- a/src/Access/AccessType.h +++ b/src/Access/AccessType.h @@ -130,6 +130,7 @@ enum class AccessType M(SYSTEM_RELOAD_CONFIG, "RELOAD CONFIG", GLOBAL, SYSTEM_RELOAD) \ M(SYSTEM_RELOAD_SYMBOLS, "RELOAD SYMBOLS", GLOBAL, SYSTEM_RELOAD) \ M(SYSTEM_RELOAD_DICTIONARY, "SYSTEM RELOAD DICTIONARIES, RELOAD DICTIONARY, RELOAD DICTIONARIES", GLOBAL, SYSTEM_RELOAD) \ + M(SYSTEM_RELOAD_MODEL, "SYSTEM RELOAD MODELS, RELOAD MODEL, RELOAD MODELS", GLOBAL, SYSTEM_RELOAD) \ M(SYSTEM_RELOAD_EMBEDDED_DICTIONARIES, "RELOAD EMBEDDED DICTIONARIES", GLOBAL, SYSTEM_RELOAD) /* implicitly enabled by the grant SYSTEM_RELOAD_DICTIONARY ON *.* */\ M(SYSTEM_RELOAD, "", GROUP, SYSTEM) \ M(SYSTEM_MERGES, "SYSTEM STOP MERGES, SYSTEM START MERGES, STOP_MERGES, START MERGES", TABLE, SYSTEM) \ diff --git a/src/Interpreters/ExternalModelsLoader.h b/src/Interpreters/ExternalModelsLoader.h index ebf6de67540..f0a7592f4d3 100644 --- a/src/Interpreters/ExternalModelsLoader.h +++ b/src/Interpreters/ExternalModelsLoader.h @@ -20,9 +20,14 @@ public: /// Models will be loaded immediately and then will be updated in separate thread, each 'reload_period' seconds. explicit ExternalModelsLoader(ContextPtr context_); - ModelPtr getModel(const std::string & name) const + ModelPtr getModel(const std::string & model_name) const { - return std::static_pointer_cast(load(name)); + return std::static_pointer_cast(load(model_name)); + } + + void reloadModel(const std::string & model_name) const + { + loadOrReload(model_name); } protected: diff --git a/src/Interpreters/InterpreterSystemQuery.cpp b/src/Interpreters/InterpreterSystemQuery.cpp index 02d5296a9d2..117aa0da2da 100644 --- a/src/Interpreters/InterpreterSystemQuery.cpp +++ b/src/Interpreters/InterpreterSystemQuery.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -286,6 +287,7 @@ BlockIO InterpreterSystemQuery::execute() auto & external_dictionaries_loader = system_context->getExternalDictionariesLoader(); external_dictionaries_loader.reloadDictionary(query.target_dictionary, getContext()); + ExternalDictionariesLoader::resetAll(); break; } @@ -299,6 +301,22 @@ BlockIO InterpreterSystemQuery::execute() ExternalDictionariesLoader::resetAll(); break; } + case Type::RELOAD_MODEL: + { + getContext()->checkAccess(AccessType::SYSTEM_RELOAD_MODEL); + + auto & external_models_loader = system_context->getExternalModelsLoader(); + external_models_loader.reloadModel(query.target_model); + break; + } + case Type::RELOAD_MODELS: + { + getContext()->checkAccess(AccessType::SYSTEM_RELOAD_MODEL); + + auto & external_models_loader = system_context->getExternalModelsLoader(); + external_models_loader.reloadAllTriedToLoad(); + break; + } case Type::RELOAD_EMBEDDED_DICTIONARIES: getContext()->checkAccess(AccessType::SYSTEM_RELOAD_EMBEDDED_DICTIONARIES); system_context->getEmbeddedDictionaries().reload(); @@ -652,6 +670,12 @@ AccessRightsElements InterpreterSystemQuery::getRequiredAccessForDDLOnCluster() required_access.emplace_back(AccessType::SYSTEM_RELOAD_DICTIONARY); break; } + case Type::RELOAD_MODEL: [[fallthrough]]; + case Type::RELOAD_MODELS: + { + required_access.emplace_back(AccessType::SYSTEM_RELOAD_MODEL); + break; + } case Type::RELOAD_CONFIG: { required_access.emplace_back(AccessType::SYSTEM_RELOAD_CONFIG); diff --git a/src/Parsers/ASTSystemQuery.cpp b/src/Parsers/ASTSystemQuery.cpp index 71bda0c7709..c929383a256 100644 --- a/src/Parsers/ASTSystemQuery.cpp +++ b/src/Parsers/ASTSystemQuery.cpp @@ -54,6 +54,10 @@ const char * ASTSystemQuery::typeToString(Type type) return "RELOAD DICTIONARY"; case Type::RELOAD_DICTIONARIES: return "RELOAD DICTIONARIES"; + case Type::RELOAD_MODEL: + return "RELOAD MODEL"; + case Type::RELOAD_MODELS: + return "RELOAD MODELS"; case Type::RELOAD_EMBEDDED_DICTIONARIES: return "RELOAD EMBEDDED DICTIONARIES"; case Type::RELOAD_CONFIG: diff --git a/src/Parsers/ASTSystemQuery.h b/src/Parsers/ASTSystemQuery.h index 5bcdcc7875d..af3244573e4 100644 --- a/src/Parsers/ASTSystemQuery.h +++ b/src/Parsers/ASTSystemQuery.h @@ -36,6 +36,8 @@ public: SYNC_REPLICA, RELOAD_DICTIONARY, RELOAD_DICTIONARIES, + RELOAD_MODEL, + RELOAD_MODELS, RELOAD_EMBEDDED_DICTIONARIES, RELOAD_CONFIG, RELOAD_SYMBOLS, @@ -63,6 +65,7 @@ public: Type type = Type::UNKNOWN; String target_dictionary; + String target_model; String database; String table; String replica; diff --git a/src/Parsers/ParserSystemQuery.cpp b/src/Parsers/ParserSystemQuery.cpp index 491037da9a9..2fc168ea167 100644 --- a/src/Parsers/ParserSystemQuery.cpp +++ b/src/Parsers/ParserSystemQuery.cpp @@ -57,7 +57,35 @@ bool ParserSystemQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Expected & return false; break; } + case Type::RELOAD_MODEL: + { + String cluster_str; + if (ParserKeyword{"ON"}.ignore(pos, expected)) + { + if (!ASTQueryWithOnCluster::parse(pos, cluster_str, expected)) + return false; + } + res->cluster = cluster_str; + ASTPtr ast; + if (ParserStringLiteral{}.parse(pos, ast, expected)) + { + res->target_model = ast->as().value.safeGet(); + } + else + { + ParserIdentifier model_parser; + ASTPtr model; + String target_model; + if (!model_parser.parse(pos, model, expected)) + return false; + + if (!tryGetIdentifierNameInto(model, res->target_model)) + return false; + } + + break; + } case Type::DROP_REPLICA: { ASTPtr ast; diff --git a/tests/integration/test_catboost_model_reload/__init__.py b/tests/integration/test_catboost_model_reload/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/integration/test_catboost_model_reload/config/catboost_lib.xml b/tests/integration/test_catboost_model_reload/config/catboost_lib.xml new file mode 100644 index 00000000000..745be7cebe6 --- /dev/null +++ b/tests/integration/test_catboost_model_reload/config/catboost_lib.xml @@ -0,0 +1,3 @@ + + /etc/clickhouse-server/model/libcatboostmodel.so + diff --git a/tests/integration/test_catboost_model_reload/config/models_config.xml b/tests/integration/test_catboost_model_reload/config/models_config.xml new file mode 100644 index 00000000000..e84ca8b5285 --- /dev/null +++ b/tests/integration/test_catboost_model_reload/config/models_config.xml @@ -0,0 +1,3 @@ + + /etc/clickhouse-server/model/model_config.xml + diff --git a/tests/integration/test_catboost_model_reload/model/conjunction.cbm b/tests/integration/test_catboost_model_reload/model/conjunction.cbm new file mode 100644 index 00000000000..7b75fb5f886 Binary files /dev/null and b/tests/integration/test_catboost_model_reload/model/conjunction.cbm differ diff --git a/tests/integration/test_catboost_model_reload/model/disjunction.cbm b/tests/integration/test_catboost_model_reload/model/disjunction.cbm new file mode 100644 index 00000000000..8145c24637f Binary files /dev/null and b/tests/integration/test_catboost_model_reload/model/disjunction.cbm differ diff --git a/tests/integration/test_catboost_model_reload/model/libcatboostmodel.so b/tests/integration/test_catboost_model_reload/model/libcatboostmodel.so new file mode 100755 index 00000000000..388d9f887b4 Binary files /dev/null and b/tests/integration/test_catboost_model_reload/model/libcatboostmodel.so differ diff --git a/tests/integration/test_catboost_model_reload/model/model_config.xml b/tests/integration/test_catboost_model_reload/model/model_config.xml new file mode 100644 index 00000000000..7cbda165ce9 --- /dev/null +++ b/tests/integration/test_catboost_model_reload/model/model_config.xml @@ -0,0 +1,8 @@ + + + catboost + model + /etc/clickhouse-server/model/model.cbm + 0 + + diff --git a/tests/integration/test_catboost_model_reload/test.py b/tests/integration/test_catboost_model_reload/test.py new file mode 100644 index 00000000000..8283e6af975 --- /dev/null +++ b/tests/integration/test_catboost_model_reload/test.py @@ -0,0 +1,74 @@ +import os +import sys +import time + +import pytest + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) + +from helpers.cluster import ClickHouseCluster + +cluster = ClickHouseCluster(__file__) +node = cluster.add_instance('node', stay_alive=True, main_configs=['config/models_config.xml', 'config/catboost_lib.xml']) + +def copy_file_to_container(local_path, dist_path, container_id): + os.system("docker cp {local} {cont_id}:{dist}".format(local=local_path, cont_id=container_id, dist=dist_path)) + +@pytest.fixture(scope="module") +def started_cluster(): + try: + cluster.start() + + copy_file_to_container(os.path.join(SCRIPT_DIR, 'model/.'), '/etc/clickhouse-server/model', node.docker_id) + node.query("CREATE TABLE binary (x UInt64, y UInt64) ENGINE = TinyLog()") + node.query("INSERT INTO binary VALUES (1, 1), (1, 0), (0, 1), (0, 0)") + + node.restart_clickhouse() + + yield cluster + + finally: + cluster.shutdown() + +def test_model_reload(started_cluster): + node.exec_in_container(["bash", "-c", "rm -f /etc/clickhouse-server/model/model.cbm"]) + node.exec_in_container(["bash", "-c", "ln /etc/clickhouse-server/model/conjunction.cbm /etc/clickhouse-server/model/model.cbm"]) + node.query("SYSTEM RELOAD MODEL model") + + result = node.query(""" + WITH modelEvaluate('model', toFloat64(x), toFloat64(y)) as prediction, exp(prediction) / (1 + exp(prediction)) as probability + SELECT if(probability > 0.5, 1, 0) FROM binary; + """) + assert result == '1\n0\n0\n0\n' + + node.exec_in_container(["bash", "-c", "rm /etc/clickhouse-server/model/model.cbm"]) + node.exec_in_container(["bash", "-c", "ln /etc/clickhouse-server/model/disjunction.cbm /etc/clickhouse-server/model/model.cbm"]) + node.query("SYSTEM RELOAD MODEL model") + + result = node.query(""" + WITH modelEvaluate('model', toFloat64(x), toFloat64(y)) as prediction, exp(prediction) / (1 + exp(prediction)) as probability + SELECT if(probability > 0.5, 1, 0) FROM binary; + """) + assert result == '1\n1\n1\n0\n' + +def test_models_reload(started_cluster): + node.exec_in_container(["bash", "-c", "rm -f /etc/clickhouse-server/model/model.cbm"]) + node.exec_in_container(["bash", "-c", "ln /etc/clickhouse-server/model/conjunction.cbm /etc/clickhouse-server/model/model.cbm"]) + node.query("SYSTEM RELOAD MODELS") + + result = node.query(""" + WITH modelEvaluate('model', toFloat64(x), toFloat64(y)) as prediction, exp(prediction) / (1 + exp(prediction)) as probability + SELECT if(probability > 0.5, 1, 0) FROM binary; + """) + assert result == '1\n0\n0\n0\n' + + node.exec_in_container(["bash", "-c", "rm /etc/clickhouse-server/model/model.cbm"]) + node.exec_in_container(["bash", "-c", "ln /etc/clickhouse-server/model/disjunction.cbm /etc/clickhouse-server/model/model.cbm"]) + node.query("SYSTEM RELOAD MODELS") + + result = node.query(""" + WITH modelEvaluate('model', toFloat64(x), toFloat64(y)) as prediction, exp(prediction) / (1 + exp(prediction)) as probability + SELECT if(probability > 0.5, 1, 0) FROM binary; + """) + assert result == '1\n1\n1\n0\n' \ No newline at end of file diff --git a/tests/queries/0_stateless/01271_show_privileges.reference b/tests/queries/0_stateless/01271_show_privileges.reference index c8b8662dc3e..892bd95d2d9 100644 --- a/tests/queries/0_stateless/01271_show_privileges.reference +++ b/tests/queries/0_stateless/01271_show_privileges.reference @@ -82,6 +82,7 @@ SYSTEM DROP CACHE ['DROP CACHE'] \N SYSTEM SYSTEM RELOAD CONFIG ['RELOAD CONFIG'] GLOBAL SYSTEM RELOAD SYSTEM RELOAD SYMBOLS ['RELOAD SYMBOLS'] GLOBAL SYSTEM RELOAD SYSTEM RELOAD DICTIONARY ['SYSTEM RELOAD DICTIONARIES','RELOAD DICTIONARY','RELOAD DICTIONARIES'] GLOBAL SYSTEM RELOAD +SYSTEM RELOAD MODEL ['SYSTEM RELOAD MODELS','RELOAD MODEL','RELOAD MODELS'] GLOBAL SYSTEM RELOAD SYSTEM RELOAD EMBEDDED DICTIONARIES ['RELOAD EMBEDDED DICTIONARIES'] GLOBAL SYSTEM RELOAD SYSTEM RELOAD [] \N SYSTEM SYSTEM MERGES ['SYSTEM STOP MERGES','SYSTEM START MERGES','STOP_MERGES','START MERGES'] TABLE SYSTEM