Merge pull request #61793 from nickitat/keep_alive_max_reqs

Implement `max_keep_alive_requests` setting for server
This commit is contained in:
Nikita Taranov 2024-08-13 15:06:43 +00:00 committed by GitHub
commit dd3fa7c3b6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
44 changed files with 179 additions and 150 deletions

View File

@ -58,6 +58,10 @@ namespace Net
void setKeepAliveTimeout(Poco::Timespan keepAliveTimeout);
size_t getKeepAliveTimeout() const { return _keepAliveTimeout.totalSeconds(); }
size_t getMaxKeepAliveRequests() const { return _maxKeepAliveRequests; }
private:
bool _firstRequest;
Poco::Timespan _keepAliveTimeout;

View File

@ -19,11 +19,11 @@ namespace Poco {
namespace Net {
HTTPServerSession::HTTPServerSession(const StreamSocket& socket, HTTPServerParams::Ptr pParams):
HTTPSession(socket, pParams->getKeepAlive()),
_firstRequest(true),
_keepAliveTimeout(pParams->getKeepAliveTimeout()),
_maxKeepAliveRequests(pParams->getMaxKeepAliveRequests())
HTTPServerSession::HTTPServerSession(const StreamSocket & socket, HTTPServerParams::Ptr pParams)
: HTTPSession(socket, pParams->getKeepAlive())
, _firstRequest(true)
, _keepAliveTimeout(pParams->getKeepAliveTimeout())
, _maxKeepAliveRequests(pParams->getMaxKeepAliveRequests())
{
setTimeout(pParams->getTimeout());
}
@ -52,11 +52,12 @@ bool HTTPServerSession::hasMoreRequests()
}
else if (_maxKeepAliveRequests != 0 && getKeepAlive())
{
if (_maxKeepAliveRequests > 0)
--_maxKeepAliveRequests;
return buffered() > 0 || socket().poll(_keepAliveTimeout, Socket::SELECT_READ);
}
else return false;
if (_maxKeepAliveRequests > 0)
--_maxKeepAliveRequests;
return buffered() > 0 || socket().poll(_keepAliveTimeout, Socket::SELECT_READ);
}
else
return false;
}

View File

@ -1400,6 +1400,16 @@ The number of seconds that ClickHouse waits for incoming requests before closing
<keep_alive_timeout>10</keep_alive_timeout>
```
## max_keep_alive_requests {#max-keep-alive-requests}
Maximal number of requests through a single keep-alive connection until it will be closed by ClickHouse server. Default to 10000.
**Example**
``` xml
<max_keep_alive_requests>10</max_keep_alive_requests>
```
## listen_host {#listen_host}
Restriction on hosts that requests can come from. If you want the server to answer all of them, specify `::`.

View File

@ -27,7 +27,7 @@ std::string LibraryBridge::bridgeName() const
LibraryBridge::HandlerFactoryPtr LibraryBridge::getHandlerFactoryPtr(ContextPtr context) const
{
return std::make_shared<LibraryBridgeHandlerFactory>("LibraryRequestHandlerFactory", keep_alive_timeout, context);
return std::make_shared<LibraryBridgeHandlerFactory>("LibraryRequestHandlerFactory", context);
}
}

View File

@ -9,12 +9,10 @@ namespace DB
{
LibraryBridgeHandlerFactory::LibraryBridgeHandlerFactory(
const std::string & name_,
size_t keep_alive_timeout_,
ContextPtr context_)
: WithContext(context_)
, log(getLogger(name_))
, name(name_)
, keep_alive_timeout(keep_alive_timeout_)
{
}
@ -26,17 +24,17 @@ std::unique_ptr<HTTPRequestHandler> LibraryBridgeHandlerFactory::createRequestHa
if (request.getMethod() == Poco::Net::HTTPRequest::HTTP_GET)
{
if (uri.getPath() == "/extdict_ping")
return std::make_unique<ExternalDictionaryLibraryBridgeExistsHandler>(keep_alive_timeout, getContext());
return std::make_unique<ExternalDictionaryLibraryBridgeExistsHandler>(getContext());
else if (uri.getPath() == "/catboost_ping")
return std::make_unique<CatBoostLibraryBridgeExistsHandler>(keep_alive_timeout, getContext());
return std::make_unique<CatBoostLibraryBridgeExistsHandler>(getContext());
}
if (request.getMethod() == Poco::Net::HTTPRequest::HTTP_POST)
{
if (uri.getPath() == "/extdict_request")
return std::make_unique<ExternalDictionaryLibraryBridgeRequestHandler>(keep_alive_timeout, getContext());
return std::make_unique<ExternalDictionaryLibraryBridgeRequestHandler>(getContext());
else if (uri.getPath() == "/catboost_request")
return std::make_unique<CatBoostLibraryBridgeRequestHandler>(keep_alive_timeout, getContext());
return std::make_unique<CatBoostLibraryBridgeRequestHandler>(getContext());
}
return nullptr;

View File

@ -13,7 +13,6 @@ class LibraryBridgeHandlerFactory : public HTTPRequestHandlerFactory, WithContex
public:
LibraryBridgeHandlerFactory(
const std::string & name_,
size_t keep_alive_timeout_,
ContextPtr context_);
std::unique_ptr<HTTPRequestHandler> createRequestHandler(const HTTPServerRequest & request) override;
@ -21,7 +20,6 @@ public:
private:
LoggerPtr log;
const std::string name;
const size_t keep_alive_timeout;
};
}

View File

@ -87,10 +87,8 @@ static void writeData(Block data, OutputFormatPtr format)
}
ExternalDictionaryLibraryBridgeRequestHandler::ExternalDictionaryLibraryBridgeRequestHandler(size_t keep_alive_timeout_, ContextPtr context_)
: WithContext(context_)
, keep_alive_timeout(keep_alive_timeout_)
, log(getLogger("ExternalDictionaryLibraryBridgeRequestHandler"))
ExternalDictionaryLibraryBridgeRequestHandler::ExternalDictionaryLibraryBridgeRequestHandler(ContextPtr context_)
: WithContext(context_), log(getLogger("ExternalDictionaryLibraryBridgeRequestHandler"))
{
}
@ -137,7 +135,7 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
const String & dictionary_id = params.get("dictionary_id");
LOG_TRACE(log, "Library method: '{}', dictionary id: {}", method, dictionary_id);
WriteBufferFromHTTPServerResponse out(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD, keep_alive_timeout);
WriteBufferFromHTTPServerResponse out(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD);
try
{
@ -374,10 +372,8 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
}
ExternalDictionaryLibraryBridgeExistsHandler::ExternalDictionaryLibraryBridgeExistsHandler(size_t keep_alive_timeout_, ContextPtr context_)
: WithContext(context_)
, keep_alive_timeout(keep_alive_timeout_)
, log(getLogger("ExternalDictionaryLibraryBridgeExistsHandler"))
ExternalDictionaryLibraryBridgeExistsHandler::ExternalDictionaryLibraryBridgeExistsHandler(ContextPtr context_)
: WithContext(context_), log(getLogger("ExternalDictionaryLibraryBridgeExistsHandler"))
{
}
@ -401,7 +397,7 @@ void ExternalDictionaryLibraryBridgeExistsHandler::handleRequest(HTTPServerReque
String res = library_handler ? "1" : "0";
setResponseDefaultHeaders(response, keep_alive_timeout);
setResponseDefaultHeaders(response);
LOG_TRACE(log, "Sending ping response: {} (dictionary id: {})", res, dictionary_id);
response.sendBuffer(res.data(), res.size());
}
@ -412,11 +408,8 @@ void ExternalDictionaryLibraryBridgeExistsHandler::handleRequest(HTTPServerReque
}
CatBoostLibraryBridgeRequestHandler::CatBoostLibraryBridgeRequestHandler(
size_t keep_alive_timeout_, ContextPtr context_)
: WithContext(context_)
, keep_alive_timeout(keep_alive_timeout_)
, log(getLogger("CatBoostLibraryBridgeRequestHandler"))
CatBoostLibraryBridgeRequestHandler::CatBoostLibraryBridgeRequestHandler(ContextPtr context_)
: WithContext(context_), log(getLogger("CatBoostLibraryBridgeRequestHandler"))
{
}
@ -455,7 +448,7 @@ void CatBoostLibraryBridgeRequestHandler::handleRequest(HTTPServerRequest & requ
const String & method = params.get("method");
LOG_TRACE(log, "Library method: '{}'", method);
WriteBufferFromHTTPServerResponse out(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD, keep_alive_timeout);
WriteBufferFromHTTPServerResponse out(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD);
try
{
@ -617,10 +610,8 @@ void CatBoostLibraryBridgeRequestHandler::handleRequest(HTTPServerRequest & requ
}
CatBoostLibraryBridgeExistsHandler::CatBoostLibraryBridgeExistsHandler(size_t keep_alive_timeout_, ContextPtr context_)
: WithContext(context_)
, keep_alive_timeout(keep_alive_timeout_)
, log(getLogger("CatBoostLibraryBridgeExistsHandler"))
CatBoostLibraryBridgeExistsHandler::CatBoostLibraryBridgeExistsHandler(ContextPtr context_)
: WithContext(context_), log(getLogger("CatBoostLibraryBridgeExistsHandler"))
{
}
@ -634,7 +625,7 @@ void CatBoostLibraryBridgeExistsHandler::handleRequest(HTTPServerRequest & reque
String res = "1";
setResponseDefaultHeaders(response, keep_alive_timeout);
setResponseDefaultHeaders(response);
LOG_TRACE(log, "Sending ping response: {}", res);
response.sendBuffer(res.data(), res.size());
}

View File

@ -18,14 +18,13 @@ namespace DB
class ExternalDictionaryLibraryBridgeRequestHandler : public HTTPRequestHandler, WithContext
{
public:
ExternalDictionaryLibraryBridgeRequestHandler(size_t keep_alive_timeout_, ContextPtr context_);
explicit ExternalDictionaryLibraryBridgeRequestHandler(ContextPtr context_);
void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response, const ProfileEvents::Event & write_event) override;
private:
static constexpr auto FORMAT = "RowBinary";
const size_t keep_alive_timeout;
LoggerPtr log;
};
@ -34,12 +33,11 @@ private:
class ExternalDictionaryLibraryBridgeExistsHandler : public HTTPRequestHandler, WithContext
{
public:
ExternalDictionaryLibraryBridgeExistsHandler(size_t keep_alive_timeout_, ContextPtr context_);
explicit ExternalDictionaryLibraryBridgeExistsHandler(ContextPtr context_);
void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response, const ProfileEvents::Event & write_event) override;
private:
const size_t keep_alive_timeout;
LoggerPtr log;
};
@ -63,12 +61,11 @@ private:
class CatBoostLibraryBridgeRequestHandler : public HTTPRequestHandler, WithContext
{
public:
CatBoostLibraryBridgeRequestHandler(size_t keep_alive_timeout_, ContextPtr context_);
explicit CatBoostLibraryBridgeRequestHandler(ContextPtr context_);
void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response, const ProfileEvents::Event & write_event) override;
private:
const size_t keep_alive_timeout;
LoggerPtr log;
};
@ -77,12 +74,11 @@ private:
class CatBoostLibraryBridgeExistsHandler : public HTTPRequestHandler, WithContext
{
public:
CatBoostLibraryBridgeExistsHandler(size_t keep_alive_timeout_, ContextPtr context_);
explicit CatBoostLibraryBridgeExistsHandler(ContextPtr context_);
void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response, const ProfileEvents::Event & write_event) override;
private:
const size_t keep_alive_timeout;
LoggerPtr log;
};

View File

@ -202,10 +202,7 @@ void ODBCColumnsInfoHandler::handleRequest(HTTPServerRequest & request, HTTPServ
if (columns.empty())
throw Exception(ErrorCodes::UNKNOWN_TABLE, "Columns definition was not returned");
WriteBufferFromHTTPServerResponse out(
response,
request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD,
keep_alive_timeout);
WriteBufferFromHTTPServerResponse out(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD);
try
{
writeStringBinary(columns.toString(), out);

View File

@ -15,18 +15,12 @@ namespace DB
class ODBCColumnsInfoHandler : public HTTPRequestHandler, WithContext
{
public:
ODBCColumnsInfoHandler(size_t keep_alive_timeout_, ContextPtr context_)
: WithContext(context_)
, log(getLogger("ODBCColumnsInfoHandler"))
, keep_alive_timeout(keep_alive_timeout_)
{
}
explicit ODBCColumnsInfoHandler(ContextPtr context_) : WithContext(context_), log(getLogger("ODBCColumnsInfoHandler")) { }
void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response, const ProfileEvents::Event & write_event) override;
private:
LoggerPtr log;
size_t keep_alive_timeout;
};
}

View File

@ -74,7 +74,7 @@ void IdentifierQuoteHandler::handleRequest(HTTPServerRequest & request, HTTPServ
auto identifier = getIdentifierQuote(std::move(connection));
WriteBufferFromHTTPServerResponse out(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD, keep_alive_timeout);
WriteBufferFromHTTPServerResponse out(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD);
try
{
writeStringBinary(identifier, out);

View File

@ -14,18 +14,12 @@ namespace DB
class IdentifierQuoteHandler : public HTTPRequestHandler, WithContext
{
public:
IdentifierQuoteHandler(size_t keep_alive_timeout_, ContextPtr context_)
: WithContext(context_)
, log(getLogger("IdentifierQuoteHandler"))
, keep_alive_timeout(keep_alive_timeout_)
{
}
explicit IdentifierQuoteHandler(ContextPtr context_) : WithContext(context_), log(getLogger("IdentifierQuoteHandler")) { }
void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response, const ProfileEvents::Event & write_event) override;
private:
LoggerPtr log;
size_t keep_alive_timeout;
};
}

View File

@ -132,7 +132,7 @@ void ODBCHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse
return;
}
WriteBufferFromHTTPServerResponse out(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD, keep_alive_timeout);
WriteBufferFromHTTPServerResponse out(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD);
try
{

View File

@ -20,12 +20,10 @@ class ODBCHandler : public HTTPRequestHandler, WithContext
{
public:
ODBCHandler(
size_t keep_alive_timeout_,
ContextPtr context_,
const String & mode_)
: WithContext(context_)
, log(getLogger("ODBCHandler"))
, keep_alive_timeout(keep_alive_timeout_)
, mode(mode_)
{
}
@ -35,7 +33,6 @@ public:
private:
LoggerPtr log;
size_t keep_alive_timeout;
String mode;
static inline std::mutex mutex;

View File

@ -27,7 +27,7 @@ std::string ODBCBridge::bridgeName() const
ODBCBridge::HandlerFactoryPtr ODBCBridge::getHandlerFactoryPtr(ContextPtr context) const
{
return std::make_shared<ODBCBridgeHandlerFactory>("ODBCRequestHandlerFactory-factory", keep_alive_timeout, context);
return std::make_shared<ODBCBridgeHandlerFactory>("ODBCRequestHandlerFactory-factory", context);
}
}

View File

@ -9,11 +9,8 @@
namespace DB
{
ODBCBridgeHandlerFactory::ODBCBridgeHandlerFactory(const std::string & name_, size_t keep_alive_timeout_, ContextPtr context_)
: WithContext(context_)
, log(getLogger(name_))
, name(name_)
, keep_alive_timeout(keep_alive_timeout_)
ODBCBridgeHandlerFactory::ODBCBridgeHandlerFactory(const std::string & name_, ContextPtr context_)
: WithContext(context_), log(getLogger(name_)), name(name_)
{
}
@ -23,33 +20,33 @@ std::unique_ptr<HTTPRequestHandler> ODBCBridgeHandlerFactory::createRequestHandl
LOG_TRACE(log, "Request URI: {}", uri.toString());
if (uri.getPath() == "/ping" && request.getMethod() == Poco::Net::HTTPRequest::HTTP_GET)
return std::make_unique<PingHandler>(keep_alive_timeout);
return std::make_unique<PingHandler>();
if (request.getMethod() == Poco::Net::HTTPRequest::HTTP_POST)
{
if (uri.getPath() == "/columns_info")
#if USE_ODBC
return std::make_unique<ODBCColumnsInfoHandler>(keep_alive_timeout, getContext());
return std::make_unique<ODBCColumnsInfoHandler>(getContext());
#else
return nullptr;
#endif
else if (uri.getPath() == "/identifier_quote")
#if USE_ODBC
return std::make_unique<IdentifierQuoteHandler>(keep_alive_timeout, getContext());
return std::make_unique<IdentifierQuoteHandler>(getContext());
#else
return nullptr;
#endif
else if (uri.getPath() == "/schema_allowed")
#if USE_ODBC
return std::make_unique<SchemaAllowedHandler>(keep_alive_timeout, getContext());
return std::make_unique<SchemaAllowedHandler>(getContext());
#else
return nullptr;
#endif
else if (uri.getPath() == "/write")
return std::make_unique<ODBCHandler>(keep_alive_timeout, getContext(), "write");
return std::make_unique<ODBCHandler>(getContext(), "write");
else
return std::make_unique<ODBCHandler>(keep_alive_timeout, getContext(), "read");
return std::make_unique<ODBCHandler>(getContext(), "read");
}
return nullptr;
}

View File

@ -17,14 +17,13 @@ namespace DB
class ODBCBridgeHandlerFactory : public HTTPRequestHandlerFactory, WithContext
{
public:
ODBCBridgeHandlerFactory(const std::string & name_, size_t keep_alive_timeout_, ContextPtr context_);
ODBCBridgeHandlerFactory(const std::string & name_, ContextPtr context_);
std::unique_ptr<HTTPRequestHandler> createRequestHandler(const HTTPServerRequest & request) override;
private:
LoggerPtr log;
std::string name;
size_t keep_alive_timeout;
};
}

View File

@ -10,7 +10,7 @@ void PingHandler::handleRequest(HTTPServerRequest & /* request */, HTTPServerRes
{
try
{
setResponseDefaultHeaders(response, keep_alive_timeout);
setResponseDefaultHeaders(response);
const char * data = "Ok.\n";
response.sendBuffer(data, strlen(data));
}

View File

@ -9,11 +9,7 @@ namespace DB
class PingHandler : public HTTPRequestHandler
{
public:
explicit PingHandler(size_t keep_alive_timeout_) : keep_alive_timeout(keep_alive_timeout_) {}
void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response, const ProfileEvents::Event & write_event) override;
private:
size_t keep_alive_timeout;
};
}

View File

@ -88,7 +88,7 @@ void SchemaAllowedHandler::handleRequest(HTTPServerRequest & request, HTTPServer
bool result = isSchemaAllowed(std::move(connection));
WriteBufferFromHTTPServerResponse out(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD, keep_alive_timeout);
WriteBufferFromHTTPServerResponse out(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD);
try
{
writeBoolText(result, out);

View File

@ -17,18 +17,12 @@ class Context;
class SchemaAllowedHandler : public HTTPRequestHandler, WithContext
{
public:
SchemaAllowedHandler(size_t keep_alive_timeout_, ContextPtr context_)
: WithContext(context_)
, log(getLogger("SchemaAllowedHandler"))
, keep_alive_timeout(keep_alive_timeout_)
{
}
explicit SchemaAllowedHandler(ContextPtr context_) : WithContext(context_), log(getLogger("SchemaAllowedHandler")) { }
void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response, const ProfileEvents::Event & write_event) override;
private:
LoggerPtr log;
size_t keep_alive_timeout;
};
}

View File

@ -2428,6 +2428,7 @@ void Server::createServers(
Poco::Net::HTTPServerParams::Ptr http_params = new Poco::Net::HTTPServerParams;
http_params->setTimeout(settings.http_receive_timeout);
http_params->setKeepAliveTimeout(global_context->getServerSettings().keep_alive_timeout);
http_params->setMaxKeepAliveRequests(static_cast<int>(global_context->getServerSettings().max_keep_alive_requests));
Poco::Util::AbstractConfiguration::Keys protocols;
config.keys("protocols", protocols);

View File

@ -134,6 +134,7 @@ namespace DB
M(Bool, async_load_databases, false, "Enable asynchronous loading of databases and tables to speedup server startup. Queries to not yet loaded entity will be blocked until load is finished.", 0) \
M(Bool, display_secrets_in_show_and_select, false, "Allow showing secrets in SHOW and SELECT queries via a format setting and a grant", 0) \
M(Seconds, keep_alive_timeout, DEFAULT_HTTP_KEEP_ALIVE_TIMEOUT, "The number of seconds that ClickHouse waits for incoming requests before closing the connection.", 0) \
M(UInt64, max_keep_alive_requests, 10000, "The maximum number of requests handled via a single http keepalive connection before the server closes this connection.", 0) \
M(Seconds, replicated_fetches_http_connection_timeout, 0, "HTTP connection timeout for part fetch requests. Inherited from default profile `http_connection_timeout` if not set explicitly.", 0) \
M(Seconds, replicated_fetches_http_send_timeout, 0, "HTTP send timeout for part fetch requests. Inherited from default profile `http_send_timeout` if not set explicitly.", 0) \
M(Seconds, replicated_fetches_http_receive_timeout, 0, "HTTP receive timeout for fetch part requests. Inherited from default profile `http_receive_timeout` if not set explicitly.", 0) \

View File

@ -34,14 +34,20 @@ namespace ErrorCodes
extern const int RECEIVED_ERROR_TOO_MANY_REQUESTS;
}
void setResponseDefaultHeaders(HTTPServerResponse & response, size_t keep_alive_timeout)
void setResponseDefaultHeaders(HTTPServerResponse & response)
{
if (!response.getKeepAlive())
return;
Poco::Timespan timeout(keep_alive_timeout, 0);
if (timeout.totalSeconds())
response.set("Keep-Alive", "timeout=" + std::to_string(timeout.totalSeconds()));
const size_t keep_alive_timeout = response.getSession().getKeepAliveTimeout();
const size_t keep_alive_max_requests = response.getSession().getMaxKeepAliveRequests();
if (keep_alive_timeout)
{
if (keep_alive_max_requests)
response.set("Keep-Alive", fmt::format("timeout={}, max={}", keep_alive_timeout, keep_alive_max_requests));
else
response.set("Keep-Alive", fmt::format("timeout={}", keep_alive_timeout));
}
}
HTTPSessionPtr makeHTTPSession(

View File

@ -54,7 +54,7 @@ private:
using HTTPSessionPtr = std::shared_ptr<Poco::Net::HTTPClientSession>;
void setResponseDefaultHeaders(HTTPServerResponse & response, size_t keep_alive_timeout);
void setResponseDefaultHeaders(HTTPServerResponse & response);
/// Create session object to perform requests and set required parameters.
HTTPSessionPtr makeHTTPSession(

View File

@ -248,6 +248,8 @@ public:
void attachRequest(HTTPServerRequest * request_) { request = request_; }
const Poco::Net::HTTPServerSession & getSession() const { return session; }
private:
Poco::Net::HTTPServerSession & session;
HTTPServerRequest * request = nullptr;

View File

@ -30,7 +30,7 @@ void WriteBufferFromHTTPServerResponse::startSendHeaders()
if (add_cors_header)
response.set("Access-Control-Allow-Origin", "*");
setResponseDefaultHeaders(response, keep_alive_timeout);
setResponseDefaultHeaders(response);
std::stringstream header; //STYLE_CHECK_ALLOW_STD_STRING_STREAM
response.beginWrite(header);
@ -119,12 +119,10 @@ void WriteBufferFromHTTPServerResponse::nextImpl()
WriteBufferFromHTTPServerResponse::WriteBufferFromHTTPServerResponse(
HTTPServerResponse & response_,
bool is_http_method_head_,
UInt64 keep_alive_timeout_,
const ProfileEvents::Event & write_event_)
: HTTPWriteBuffer(response_.getSocket(), write_event_)
, response(response_)
, is_http_method_head(is_http_method_head_)
, keep_alive_timeout(keep_alive_timeout_)
{
}

View File

@ -29,7 +29,6 @@ public:
WriteBufferFromHTTPServerResponse(
HTTPServerResponse & response_,
bool is_http_method_head_,
UInt64 keep_alive_timeout_,
const ProfileEvents::Event & write_event_ = ProfileEvents::end());
~WriteBufferFromHTTPServerResponse() override;
@ -91,7 +90,6 @@ private:
bool is_http_method_head;
bool add_cors_header = false;
size_t keep_alive_timeout = 0;
bool initialized = false;

View File

@ -29,7 +29,7 @@ void sendExceptionToHTTPClient(
if (!out)
{
/// If nothing was sent yet.
WriteBufferFromHTTPServerResponse out_for_message{response, request.getMethod() == HTTPRequest::HTTP_HEAD, DEFAULT_HTTP_KEEP_ALIVE_TIMEOUT};
WriteBufferFromHTTPServerResponse out_for_message{response, request.getMethod() == HTTPRequest::HTTP_HEAD};
out_for_message.writeln(exception_message);
out_for_message.finalize();

View File

@ -266,7 +266,6 @@ void HTTPHandler::processQuery(
std::make_shared<WriteBufferFromHTTPServerResponse>(
response,
request.getMethod() == HTTPRequest::HTTP_HEAD,
context->getServerSettings().keep_alive_timeout.totalSeconds(),
write_event);
used_output.out = used_output.out_holder;
used_output.out_maybe_compressed = used_output.out_holder;
@ -558,7 +557,7 @@ try
if (!used_output.out_holder && !used_output.exception_is_written)
{
/// If nothing was sent yet and we don't even know if we must compress the response.
WriteBufferFromHTTPServerResponse(response, request.getMethod() == HTTPRequest::HTTP_HEAD, DEFAULT_HTTP_KEEP_ALIVE_TIMEOUT).writeln(s);
WriteBufferFromHTTPServerResponse(response, request.getMethod() == HTTPRequest::HTTP_HEAD).writeln(s);
}
else if (used_output.out_maybe_compressed)
{

View File

@ -122,7 +122,8 @@ static inline auto createHandlersFactoryFromConfig(
}
else if (handler_type == "prometheus")
{
main_handler_factory->addHandler(createPrometheusHandlerFactoryForHTTPRule(server, config, prefix + "." + key, async_metrics));
main_handler_factory->addHandler(
createPrometheusHandlerFactoryForHTTPRule(server, config, prefix + "." + key, async_metrics));
}
else if (handler_type == "replicas_status")
{

View File

@ -87,9 +87,8 @@ void InterserverIOHTTPHandler::handleRequest(HTTPServerRequest & request, HTTPSe
response.setChunkedTransferEncoding(true);
Output used_output;
const auto keep_alive_timeout = server.context()->getServerSettings().keep_alive_timeout.totalSeconds();
used_output.out = std::make_shared<WriteBufferFromHTTPServerResponse>(
response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD, keep_alive_timeout, write_event);
response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD, write_event);
auto finalize_output = [&]
{

View File

@ -95,7 +95,7 @@ public:
class PrometheusRequestHandler::ImplWithContext : public Impl
{
public:
explicit ImplWithContext(PrometheusRequestHandler & parent) : Impl(parent), default_settings(parent.server.context()->getSettingsRef()) { }
explicit ImplWithContext(PrometheusRequestHandler & parent) : Impl(parent), default_settings(server().context()->getSettingsRef()) { }
virtual void handlingRequestWithContext(HTTPServerRequest & request, HTTPServerResponse & response) = 0;
@ -353,7 +353,7 @@ void PrometheusRequestHandler::handleRequest(HTTPServerRequest & request, HTTPSe
if (request.getVersion() == HTTPServerRequest::HTTP_1_1)
response.setChunkedTransferEncoding(true);
setResponseDefaultHeaders(response, config.keep_alive_timeout);
setResponseDefaultHeaders(response);
impl->beforeHandlingRequest(request);
impl->handleRequest(request, response);
@ -379,7 +379,7 @@ WriteBufferFromHTTPServerResponse & PrometheusRequestHandler::getOutputStream(HT
if (write_buffer_from_response)
return *write_buffer_from_response;
write_buffer_from_response = std::make_unique<WriteBufferFromHTTPServerResponse>(
response, http_method == HTTPRequest::HTTP_HEAD, config.keep_alive_timeout, write_event);
response, http_method == HTTPRequest::HTTP_HEAD, write_event);
return *write_buffer_from_response;
}
@ -399,7 +399,7 @@ void PrometheusRequestHandler::finalizeResponse(HTTPServerResponse & response)
if (write_buffer_from_response)
std::exchange(write_buffer_from_response, {})->finalize();
else
WriteBufferFromHTTPServerResponse{response, http_method == HTTPRequest::HTTP_HEAD, config.keep_alive_timeout, write_event}.finalize();
WriteBufferFromHTTPServerResponse{response, http_method == HTTPRequest::HTTP_HEAD, write_event}.finalize();
}
chassert(response_finalized && !write_buffer_from_response);
}

View File

@ -15,8 +15,11 @@ class WriteBufferFromHTTPServerResponse;
class PrometheusRequestHandler : public HTTPRequestHandler
{
public:
PrometheusRequestHandler(IServer & server_, const PrometheusRequestHandlerConfig & config_,
const AsynchronousMetrics & async_metrics_, std::shared_ptr<PrometheusMetricsWriter> metrics_writer_);
PrometheusRequestHandler(
IServer & server_,
const PrometheusRequestHandlerConfig & config_,
const AsynchronousMetrics & async_metrics_,
std::shared_ptr<PrometheusMetricsWriter> metrics_writer_);
~PrometheusRequestHandler() override;
void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response, const ProfileEvents::Event & write_event_) override;

View File

@ -89,8 +89,7 @@ void ReplicasStatusHandler::handleRequest(HTTPServerRequest & request, HTTPServe
}
}
const auto & server_settings = getContext()->getServerSettings();
setResponseDefaultHeaders(response, server_settings.keep_alive_timeout.totalSeconds());
setResponseDefaultHeaders(response);
if (!ok)
{

View File

@ -35,10 +35,9 @@ namespace ErrorCodes
extern const int INVALID_CONFIG_PARAMETER;
}
static inline std::unique_ptr<WriteBuffer>
responseWriteBuffer(HTTPServerRequest & request, HTTPServerResponse & response, UInt64 keep_alive_timeout)
static inline std::unique_ptr<WriteBuffer> responseWriteBuffer(HTTPServerRequest & request, HTTPServerResponse & response)
{
auto buf = std::unique_ptr<WriteBuffer>(new WriteBufferFromHTTPServerResponse(response, request.getMethod() == HTTPRequest::HTTP_HEAD, keep_alive_timeout));
auto buf = std::unique_ptr<WriteBuffer>(new WriteBufferFromHTTPServerResponse(response, request.getMethod() == HTTPRequest::HTTP_HEAD));
/// The client can pass a HTTP header indicating supported compression method (gzip or deflate).
String http_response_compression_methods = request.get("Accept-Encoding", "");
@ -91,8 +90,7 @@ static inline void trySendExceptionToClient(
void StaticRequestHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response, const ProfileEvents::Event & /*write_event*/)
{
auto keep_alive_timeout = server.context()->getServerSettings().keep_alive_timeout.totalSeconds();
auto out = responseWriteBuffer(request, response, keep_alive_timeout);
auto out = responseWriteBuffer(request, response);
try
{
@ -107,7 +105,7 @@ void StaticRequestHandler::handleRequest(HTTPServerRequest & request, HTTPServer
"The Transfer-Encoding is not chunked and there "
"is no Content-Length header for POST request");
setResponseDefaultHeaders(response, keep_alive_timeout);
setResponseDefaultHeaders(response);
response.setStatusAndReason(Poco::Net::HTTPResponse::HTTPStatus(status));
writeResponse(*out);
}

View File

@ -30,23 +30,20 @@ DashboardWebUIRequestHandler::DashboardWebUIRequestHandler(IServer & server_) :
BinaryWebUIRequestHandler::BinaryWebUIRequestHandler(IServer & server_) : server(server_) {}
JavaScriptWebUIRequestHandler::JavaScriptWebUIRequestHandler(IServer & server_) : server(server_) {}
static void handle(const IServer & server, HTTPServerRequest & request, HTTPServerResponse & response, std::string_view html)
static void handle(HTTPServerRequest & request, HTTPServerResponse & response, std::string_view html)
{
auto keep_alive_timeout = server.context()->getServerSettings().keep_alive_timeout.totalSeconds();
response.setContentType("text/html; charset=UTF-8");
if (request.getVersion() == HTTPServerRequest::HTTP_1_1)
response.setChunkedTransferEncoding(true);
setResponseDefaultHeaders(response, keep_alive_timeout);
setResponseDefaultHeaders(response);
response.setStatusAndReason(Poco::Net::HTTPResponse::HTTP_OK);
WriteBufferFromHTTPServerResponse(response, request.getMethod() == HTTPRequest::HTTP_HEAD, keep_alive_timeout).write(html.data(), html.size());
WriteBufferFromHTTPServerResponse(response, request.getMethod() == HTTPRequest::HTTP_HEAD).write(html.data(), html.size());
}
void PlayWebUIRequestHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response, const ProfileEvents::Event &)
{
handle(server, request, response, {reinterpret_cast<const char *>(gresource_play_htmlData), gresource_play_htmlSize});
handle(request, response, {reinterpret_cast<const char *>(gresource_play_htmlData), gresource_play_htmlSize});
}
void DashboardWebUIRequestHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response, const ProfileEvents::Event &)
@ -64,23 +61,23 @@ void DashboardWebUIRequestHandler::handleRequest(HTTPServerRequest & request, HT
static re2::RE2 lz_string_url = R"(https://[^\s"'`]+lz-string[^\s"'`]*\.js)";
RE2::Replace(&html, lz_string_url, "/js/lz-string.js");
handle(server, request, response, html);
handle(request, response, html);
}
void BinaryWebUIRequestHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response, const ProfileEvents::Event &)
{
handle(server, request, response, {reinterpret_cast<const char *>(gresource_binary_htmlData), gresource_binary_htmlSize});
handle(request, response, {reinterpret_cast<const char *>(gresource_binary_htmlData), gresource_binary_htmlSize});
}
void JavaScriptWebUIRequestHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response, const ProfileEvents::Event &)
{
if (request.getURI() == "/js/uplot.js")
{
handle(server, request, response, {reinterpret_cast<const char *>(gresource_uplot_jsData), gresource_uplot_jsSize});
handle(request, response, {reinterpret_cast<const char *>(gresource_uplot_jsData), gresource_uplot_jsSize});
}
else if (request.getURI() == "/js/lz-string.js")
{
handle(server, request, response, {reinterpret_cast<const char *>(gresource_lz_string_jsData), gresource_lz_string_jsSize});
handle(request, response, {reinterpret_cast<const char *>(gresource_lz_string_jsData), gresource_lz_string_jsSize});
}
else
{
@ -88,7 +85,7 @@ void JavaScriptWebUIRequestHandler::handleRequest(HTTPServerRequest & request, H
*response.send() << "Not found.\n";
}
handle(server, request, response, {reinterpret_cast<const char *>(gresource_binary_htmlData), gresource_binary_htmlSize});
handle(request, response, {reinterpret_cast<const char *>(gresource_binary_htmlData), gresource_binary_htmlSize});
}
}

View File

@ -0,0 +1,4 @@
<clickhouse>
<keep_alive_timeout>3600</keep_alive_timeout>
<max_keep_alive_requests>5</max_keep_alive_requests>
</clickhouse>

View File

@ -0,0 +1,55 @@
import logging
import pytest
import random
import requests
from helpers.cluster import ClickHouseCluster
cluster = ClickHouseCluster(__file__)
node = cluster.add_instance("node", main_configs=["configs/keep_alive_settings.xml"])
@pytest.fixture(scope="module")
def start_cluster():
try:
logging.info("Starting cluster...")
cluster.start()
logging.info("Cluster started")
yield cluster
finally:
cluster.shutdown()
def test_max_keep_alive_requests_on_user_side(start_cluster):
# In this test we have `keep_alive_timeout` set to one hour to never trigger connection reset by timeout, `max_keep_alive_requests` is set to 5.
# We expect server to close connection after each 5 requests. We detect connection reset by change in src port.
# So the first 5 requests should come from the same port, the following 5 requests should come from another port.
log_comments = []
for _ in range(10):
rand_id = random.randint(0, 1000000)
log_comment = f"test_requests_with_keep_alive_{rand_id}"
log_comments.append(log_comment)
log_comments = sorted(log_comments)
session = requests.Session()
for i in range(10):
session.get(
f"http://{node.ip_address}:8123/?query=select%201&log_comment={log_comments[i]}"
)
ports = node.query(
f"""
SYSTEM FLUSH LOGS;
SELECT port
FROM system.query_log
WHERE log_comment IN ({", ".join(f"'{comment}'" for comment in log_comments)}) AND type = 'QueryFinish'
ORDER BY log_comment
"""
).split("\n")[:-1]
expected = 5 * [ports[0]] + [ports[5]] * 5
assert ports == expected

View File

@ -1,6 +1,6 @@
< Connection: Keep-Alive
< Keep-Alive: timeout=10
< Keep-Alive: timeout=10, max=?
< Connection: Keep-Alive
< Keep-Alive: timeout=10
< Keep-Alive: timeout=10, max=?
< Connection: Keep-Alive
< Keep-Alive: timeout=10
< Keep-Alive: timeout=10, max=?

View File

@ -6,9 +6,10 @@ CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
URL="${CLICKHOUSE_PORT_HTTP_PROTO}://${CLICKHOUSE_HOST}:${CLICKHOUSE_PORT_HTTP}/"
${CLICKHOUSE_CURL} -vsS "${URL}" --data-binary @- <<< "SELECT 1" 2>&1 | perl -lnE 'print if /Keep-Alive/';
${CLICKHOUSE_CURL} -vsS "${URL}" --data-binary @- <<< " error here " 2>&1 | perl -lnE 'print if /Keep-Alive/';
${CLICKHOUSE_CURL} -vsS "${URL}"ping 2>&1 | perl -lnE 'print if /Keep-Alive/';
# the sed command here replaces the real number of left requests with a question mark, because it can vary and we don't really have control over it
${CLICKHOUSE_CURL} -vsS "${URL}" --data-binary @- <<< "SELECT 1" 2>&1 | sed -r 's/(keep-alive: timeout=10, max=)[0-9]+/\1?/I' | grep -i 'keep-alive';
${CLICKHOUSE_CURL} -vsS "${URL}" --data-binary @- <<< " error here " 2>&1 | sed -r 's/(keep-alive: timeout=10, max=)[0-9]+/\1?/I' | grep -i 'keep-alive';
${CLICKHOUSE_CURL} -vsS "${URL}"ping 2>&1 | perl -lnE 'print if /Keep-Alive/' | sed -r 's/(keep-alive: timeout=10, max=)[0-9]+/\1?/I' | grep -i 'keep-alive';
# no keep-alive:
${CLICKHOUSE_CURL} -vsS "${URL}"404/not/found/ 2>&1 | perl -lnE 'print if /Keep-Alive/';

View File

@ -2,11 +2,11 @@ HTTP/1.1 200 OK
Connection: Keep-Alive
Content-Type: text/tab-separated-values; charset=UTF-8
Transfer-Encoding: chunked
Keep-Alive: timeout=10
Keep-Alive: timeout=10, max=?
HTTP/1.1 200 OK
Connection: Keep-Alive
Content-Type: text/tab-separated-values; charset=UTF-8
Transfer-Encoding: chunked
Keep-Alive: timeout=10
Keep-Alive: timeout=10, max=?

View File

@ -4,8 +4,9 @@ CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
# shellcheck source=../shell_config.sh
. "$CURDIR"/../shell_config.sh
( ${CLICKHOUSE_CURL} -s --head "${CLICKHOUSE_URL}&query=SELECT%201";
${CLICKHOUSE_CURL} -s --head "${CLICKHOUSE_URL}&query=select+*+from+system.numbers+limit+1000000" ) | grep -v "Date:" | grep -v "X-ClickHouse-Server-Display-Name:" | grep -v "X-ClickHouse-Query-Id:" | grep -v "X-ClickHouse-Format:" | grep -v "X-ClickHouse-Timezone:"
# the sed command here replaces the real number of left requests with a question mark, because it can vary and we don't really have control over it
( ${CLICKHOUSE_CURL} -s --head "${CLICKHOUSE_URL}&query=SELECT%201" | sed -r 's/(keep-alive: timeout=10, max=)[0-9]+/\1?/I';
${CLICKHOUSE_CURL} -s --head "${CLICKHOUSE_URL}&query=select+*+from+system.numbers+limit+1000000" ) | sed -r 's/(keep-alive: timeout=10, max=)[0-9]+/\1?/I' | grep -v "Date:" | grep -v "X-ClickHouse-Server-Display-Name:" | grep -v "X-ClickHouse-Query-Id:" | grep -v "X-ClickHouse-Format:" | grep -v "X-ClickHouse-Timezone:"
if [[ $(${CLICKHOUSE_CURL} -sS -X POST -I "${CLICKHOUSE_URL}&query=SELECT+1" | grep -c '411 Length Required') -ne 1 ]]; then
echo FAIL