From bcc0fbbf9150de48fdec116f62c0664b84c13c9f Mon Sep 17 00:00:00 2001 From: Nikolay Degterinsky Date: Thu, 10 Aug 2023 16:10:33 +0000 Subject: [PATCH 1/2] Add EXCEPT clause to SYSTEM STOP LISTEN query --- docs/en/sql-reference/statements/system.md | 6 +- programs/server/Server.cpp | 6 +- src/Parsers/ASTSystemQuery.cpp | 48 +++++++++-- src/Parsers/ParserSystemQuery.cpp | 81 ++++++++++++++----- src/Server/ServerType.cpp | 58 +++++++++---- src/Server/ServerType.h | 18 ++++- .../test_system_start_stop_listen/test.py | 70 ++++++++++++++++ 7 files changed, 238 insertions(+), 49 deletions(-) diff --git a/docs/en/sql-reference/statements/system.md b/docs/en/sql-reference/statements/system.md index 59970dbeccd..766dd2348ee 100644 --- a/docs/en/sql-reference/statements/system.md +++ b/docs/en/sql-reference/statements/system.md @@ -443,9 +443,9 @@ SYSTEM STOP LISTEN [ON CLUSTER cluster_name] [QUERIES ALL | QUERIES DEFAULT | QU ``` - If `CUSTOM 'protocol'` modifier is specified, the custom protocol with the specified name defined in the protocols section of the server configuration will be stopped. -- If `QUERIES ALL` modifier is specified, all protocols are stopped. -- If `QUERIES DEFAULT` modifier is specified, all default protocols are stopped. -- If `QUERIES CUSTOM` modifier is specified, all custom protocols are stopped. +- If `QUERIES ALL [EXCEPT .. [,..]]` modifier is specified, all protocols are stopped, unless specified with `EXCEPT` clause. +- If `QUERIES DEFAULT [EXCEPT .. [,..]]` modifier is specified, all default protocols are stopped, unless specified with `EXCEPT` clause. +- If `QUERIES CUSTOM [EXCEPT .. [,..]]` modifier is specified, all custom protocols are stopped, unless specified with `EXCEPT` clause. ### SYSTEM START LISTEN diff --git a/programs/server/Server.cpp b/programs/server/Server.cpp index e6d5837dd0e..bdff3b79a99 100644 --- a/programs/server/Server.cpp +++ b/programs/server/Server.cpp @@ -2045,6 +2045,9 @@ void Server::createServers( for (const auto & protocol : protocols) { + if (!server_type.shouldStart(ServerType::Type::CUSTOM, protocol)) + continue; + std::string prefix = "protocols." + protocol + "."; std::string port_name = prefix + "port"; std::string description {" protocol"}; @@ -2054,9 +2057,6 @@ void Server::createServers( if (!config.has(prefix + "port")) continue; - if (!server_type.shouldStart(ServerType::Type::CUSTOM, port_name)) - continue; - std::vector hosts; if (config.has(prefix + "host")) hosts.push_back(config.getString(prefix + "host")); diff --git a/src/Parsers/ASTSystemQuery.cpp b/src/Parsers/ASTSystemQuery.cpp index fb10474a4d4..9be01719d8c 100644 --- a/src/Parsers/ASTSystemQuery.cpp +++ b/src/Parsers/ASTSystemQuery.cpp @@ -204,7 +204,7 @@ void ASTSystemQuery::formatImpl(const FormatSettings & settings, FormatState &, } else if (type == Type::SUSPEND) { - settings.ostr << (settings.hilite ? hilite_keyword : "") << " FOR " + settings.ostr << (settings.hilite ? hilite_keyword : "") << " FOR " << (settings.hilite ? hilite_none : "") << seconds << (settings.hilite ? hilite_keyword : "") << " SECOND" << (settings.hilite ? hilite_none : ""); @@ -232,12 +232,50 @@ void ASTSystemQuery::formatImpl(const FormatSettings & settings, FormatState &, } else if (type == Type::START_LISTEN || type == Type::STOP_LISTEN) { - settings.ostr << (settings.hilite ? hilite_keyword : "") << " " << ServerType::serverTypeToString(server_type.type) - << (settings.hilite ? hilite_none : ""); + settings.ostr << (settings.hilite ? hilite_keyword : "") << " " + << ServerType::serverTypeToString(server_type.type) << (settings.hilite ? hilite_none : ""); - if (server_type.type == ServerType::CUSTOM) + if (server_type.type == ServerType::Type::CUSTOM) { - settings.ostr << (settings.hilite ? hilite_identifier : "") << " " << backQuoteIfNeed(server_type.custom_name); + settings.ostr << " " << quoteString(server_type.custom_name); + } + + bool comma = false; + + if (!server_type.exclude_types.empty()) + { + settings.ostr << (settings.hilite ? hilite_keyword : "") + << " EXCEPT" << (settings.hilite ? hilite_none : ""); + + for (auto cur_type : server_type.exclude_types) + { + if (cur_type == ServerType::Type::CUSTOM) + continue; + + if (comma) + settings.ostr << ","; + else + comma = true; + + settings.ostr << (settings.hilite ? hilite_keyword : "") << " " + << ServerType::serverTypeToString(cur_type) << (settings.hilite ? hilite_none : ""); + } + + if (server_type.exclude_types.contains(ServerType::Type::CUSTOM)) + { + for (const auto & cur_name : server_type.exclude_custom_names) + { + if (comma) + settings.ostr << ","; + else + comma = true; + + settings.ostr << (settings.hilite ? hilite_keyword : "") << " " + << ServerType::serverTypeToString(ServerType::Type::CUSTOM) << (settings.hilite ? hilite_none : ""); + + settings.ostr << " " << quoteString(cur_name); + } + } } } diff --git a/src/Parsers/ParserSystemQuery.cpp b/src/Parsers/ParserSystemQuery.cpp index 40fc1acae69..ac3aa41048c 100644 --- a/src/Parsers/ParserSystemQuery.cpp +++ b/src/Parsers/ParserSystemQuery.cpp @@ -458,32 +458,71 @@ bool ParserSystemQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Expected & if (!parseQueryWithOnCluster(res, pos, expected)) return false; - ServerType::Type current_type = ServerType::Type::END; - std::string current_custom_name; - - for (const auto & type : magic_enum::enum_values()) + auto parse_server_type = [&](ServerType::Type & type, std::string & custom_name) -> bool { - if (ParserKeyword{ServerType::serverTypeToString(type)}.ignore(pos, expected)) + type = ServerType::Type::END; + custom_name = ""; + + for (const auto & cur_type : magic_enum::enum_values()) { - current_type = type; - break; + if (ParserKeyword{ServerType::serverTypeToString(cur_type)}.ignore(pos, expected)) + { + type = cur_type; + break; + } + } + + if (type == ServerType::Type::END) + return false; + + if (type == ServerType::CUSTOM) + { + ASTPtr ast; + + if (!ParserStringLiteral{}.parse(pos, ast, expected)) + return false; + + custom_name = ast->as().value.get(); + } + + return true; + }; + + ServerType::Type base_type; + std::string base_custom_name; + + ServerType::Types exclude_type; + ServerType::CustomNames exclude_custom_names; + + if (!parse_server_type(base_type, base_custom_name)) + return false; + + if (ParserKeyword{"EXCEPT"}.ignore(pos, expected)) + { + if (base_type != ServerType::Type::QUERIES_ALL && + base_type != ServerType::Type::QUERIES_DEFAULT && + base_type != ServerType::Type::QUERIES_CUSTOM) + return false; + + ServerType::Type current_type; + std::string current_custom_name; + + while (true) + { + if (!exclude_type.empty() && !ParserToken(TokenType::Comma).ignore(pos, expected)) + break; + + if (!parse_server_type(current_type, current_custom_name)) + return false; + + exclude_type.insert(current_type); + + if (current_type == ServerType::Type::CUSTOM) + exclude_custom_names.insert(current_custom_name); } } - if (current_type == ServerType::Type::END) - return false; - - if (current_type == ServerType::CUSTOM) - { - ASTPtr ast; - - if (!ParserStringLiteral{}.parse(pos, ast, expected)) - return false; - - current_custom_name = ast->as().value.get(); - } - - res->server_type = ServerType(current_type, current_custom_name); + res->server_type = ServerType(base_type, base_custom_name, exclude_type, exclude_custom_names); break; } diff --git a/src/Server/ServerType.cpp b/src/Server/ServerType.cpp index 4952cd1bd24..4199a5fd042 100644 --- a/src/Server/ServerType.cpp +++ b/src/Server/ServerType.cpp @@ -42,12 +42,9 @@ const char * ServerType::serverTypeToString(ServerType::Type type) bool ServerType::shouldStart(Type server_type, const std::string & server_custom_name) const { - if (type == Type::QUERIES_ALL) - return true; - - if (type == Type::QUERIES_DEFAULT) + auto is_type_default = [](Type current_type) { - switch (server_type) + switch (current_type) { case Type::TCP: case Type::TCP_WITH_PROXY: @@ -64,21 +61,42 @@ bool ServerType::shouldStart(Type server_type, const std::string & server_custom default: return false; } + }; + + auto is_type_custom = [](Type current_type) + { + return current_type == Type::CUSTOM; + }; + + if (exclude_types.contains(Type::QUERIES_ALL)) + return false; + + if (exclude_types.contains(Type::QUERIES_DEFAULT) && is_type_default(server_type)) + return false; + + if (exclude_types.contains(Type::QUERIES_CUSTOM) && is_type_custom(server_type)) + return false; + + if (exclude_types.contains(server_type)) + { + if (server_type != Type::CUSTOM) + return false; + + if (exclude_custom_names.contains(server_custom_name)) + return false; } + if (type == Type::QUERIES_ALL) + return true; + + if (type == Type::QUERIES_DEFAULT) + return is_type_default(server_type); + if (type == Type::QUERIES_CUSTOM) - { - switch (server_type) - { - case Type::CUSTOM: - return true; - default: - return false; - } - } + return is_type_custom(server_type); if (type == Type::CUSTOM) - return server_type == type && server_custom_name == "protocols." + custom_name + ".port"; + return server_type == type && server_custom_name == custom_name; return server_type == type; } @@ -86,6 +104,7 @@ bool ServerType::shouldStart(Type server_type, const std::string & server_custom bool ServerType::shouldStop(const std::string & port_name) const { Type port_type; + std::string port_custom_name; if (port_name == "http_port") port_type = Type::HTTP; @@ -121,12 +140,19 @@ bool ServerType::shouldStop(const std::string & port_name) const port_type = Type::INTERSERVER_HTTPS; else if (port_name.starts_with("protocols.") && port_name.ends_with(".port")) + { port_type = Type::CUSTOM; + constexpr size_t protocols_size = std::string_view("protocols.").size(); + constexpr size_t ports_size = std::string_view(".ports").size(); + + port_custom_name = port_name.substr(protocols_size, port_name.size() - protocols_size - ports_size + 1); + } + else return false; - return shouldStart(port_type, port_name); + return shouldStart(port_type, port_custom_name); } } diff --git a/src/Server/ServerType.h b/src/Server/ServerType.h index 1fab492222a..bfbe692f5bd 100644 --- a/src/Server/ServerType.h +++ b/src/Server/ServerType.h @@ -1,6 +1,7 @@ #pragma once #include +#include namespace DB { @@ -28,8 +29,20 @@ public: END }; + using Types = std::unordered_set; + using CustomNames = std::unordered_set; + ServerType() = default; - explicit ServerType(Type type_, const std::string & custom_name_ = "") : type(type_), custom_name(custom_name_) {} + + explicit ServerType( + Type type_, + const std::string & custom_name_ = "", + const Types & exclude_types_ = {}, + const CustomNames exclude_custom_names_ = {}) + : type(type_), + custom_name(custom_name_), + exclude_types(exclude_types_), + exclude_custom_names(exclude_custom_names_) {} static const char * serverTypeToString(Type type); @@ -39,6 +52,9 @@ public: Type type; std::string custom_name; + + Types exclude_types; + CustomNames exclude_custom_names; }; } diff --git a/tests/integration/test_system_start_stop_listen/test.py b/tests/integration/test_system_start_stop_listen/test.py index 1925685af03..8a3081e0c15 100644 --- a/tests/integration/test_system_start_stop_listen/test.py +++ b/tests/integration/test_system_start_stop_listen/test.py @@ -143,3 +143,73 @@ def test_all_protocols(started_cluster): backup_node.query("SYSTEM START LISTEN ON CLUSTER default QUERIES ALL") assert_everything_works() + + +def test_except(started_cluster): + custom_client = Client(main_node.ip_address, 9001, command=cluster.client_bin_path) + assert_everything_works() + + # STOP LISTEN QUERIES ALL EXCEPT + main_node.query("SYSTEM STOP LISTEN QUERIES ALL EXCEPT MYSQL, CUSTOM 'tcp'") + assert "Connection refused" in main_node.query_and_get_error(QUERY) + custom_client.query(MYSQL_QUERY) + assert http_works() == False + assert http_works(8124) == False + + # START LISTEN QUERIES ALL EXCEPT + backup_node.query("SYSTEM START LISTEN ON CLUSTER default QUERIES ALL EXCEPT TCP") + assert "Connection refused" in main_node.query_and_get_error(QUERY) + custom_client.query(MYSQL_QUERY) + assert http_works() == True + assert http_works(8124) == True + backup_node.query("SYSTEM START LISTEN ON CLUSTER default QUERIES ALL") + + assert_everything_works() + + # STOP LISTEN QUERIES DEFAULT EXCEPT + main_node.query("SYSTEM STOP LISTEN QUERIES DEFAULT EXCEPT TCP") + main_node.query(QUERY) + assert "Connections to mysql failed" in custom_client.query_and_get_error( + MYSQL_QUERY + ) + custom_client.query(QUERY) + assert http_works() == False + assert http_works(8124) == True + + # START LISTEN QUERIES DEFAULT EXCEPT + backup_node.query( + "SYSTEM START LISTEN ON CLUSTER default QUERIES DEFAULT EXCEPT HTTP" + ) + main_node.query(QUERY) + main_node.query(MYSQL_QUERY) + custom_client.query(QUERY) + assert http_works() == False + assert http_works(8124) == True + + backup_node.query("SYSTEM START LISTEN ON CLUSTER default QUERIES ALL") + + assert_everything_works() + + # STOP LISTEN QUERIES CUSTOM EXCEPT + main_node.query("SYSTEM STOP LISTEN QUERIES CUSTOM EXCEPT CUSTOM 'tcp'") + main_node.query(QUERY) + custom_client.query(MYSQL_QUERY) + custom_client.query(QUERY) + assert http_works() == True + assert http_works(8124) == False + + main_node.query("SYSTEM STOP LISTEN QUERIES CUSTOM") + + # START LISTEN QUERIES DEFAULT EXCEPT + backup_node.query( + "SYSTEM START LISTEN ON CLUSTER default QUERIES CUSTOM EXCEPT CUSTOM 'tcp'" + ) + main_node.query(QUERY) + main_node.query(MYSQL_QUERY) + assert "Connection refused" in custom_client.query_and_get_error(QUERY) + assert http_works() == True + assert http_works(8124) == True + + backup_node.query("SYSTEM START LISTEN ON CLUSTER default QUERIES ALL") + + assert_everything_works() From c6fc31c1e36cd981c370fd98baf83cc670c7d584 Mon Sep 17 00:00:00 2001 From: Nikolay Degterinsky Date: Wed, 16 Aug 2023 23:06:42 +0000 Subject: [PATCH 2/2] Review suggestions --- src/Server/ServerType.cpp | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/Server/ServerType.cpp b/src/Server/ServerType.cpp index 883be4b0892..fb052e7d6e6 100644 --- a/src/Server/ServerType.cpp +++ b/src/Server/ServerType.cpp @@ -63,18 +63,13 @@ bool ServerType::shouldStart(Type server_type, const std::string & server_custom } }; - auto is_type_custom = [](Type current_type) - { - return current_type == Type::CUSTOM; - }; - if (exclude_types.contains(Type::QUERIES_ALL)) return false; if (exclude_types.contains(Type::QUERIES_DEFAULT) && is_type_default(server_type)) return false; - if (exclude_types.contains(Type::QUERIES_CUSTOM) && is_type_custom(server_type)) + if (exclude_types.contains(Type::QUERIES_CUSTOM) && server_type == Type::CUSTOM) return false; if (exclude_types.contains(server_type)) @@ -93,7 +88,7 @@ bool ServerType::shouldStart(Type server_type, const std::string & server_custom return is_type_default(server_type); if (type == Type::QUERIES_CUSTOM) - return is_type_custom(server_type); + return server_type == Type::CUSTOM; if (type == Type::CUSTOM) return server_type == type && server_custom_name == custom_name;