diff --git a/programs/server/Server.cpp b/programs/server/Server.cpp index df6b40cd347..c86a33ba60c 100644 --- a/programs/server/Server.cpp +++ b/programs/server/Server.cpp @@ -71,6 +71,7 @@ #include #include #include +#include #include #include "MetricsTransmitter.h" #include @@ -88,6 +89,7 @@ #include #include #include +#include #include @@ -1882,6 +1884,8 @@ void Server::createServers( return TCPServerConnectionFactory::Ptr(new TCPHandlerFactory(*this, false, false)); if (type == "tls") return TCPServerConnectionFactory::Ptr(new TLSHandlerFactory(*this)); + if (type == "proxy1") + return TCPServerConnectionFactory::Ptr(new ProxyV1HandlerFactory(*this)); if (type == "mysql") return TCPServerConnectionFactory::Ptr(new MySQLHandlerFactory(*this)); if (type == "postgres") @@ -1906,51 +1910,53 @@ void Server::createServers( for (const auto & protocol : protocols) { std::string prefix = protocol + "."; + std::unordered_set pset {prefix}; if (config.has(prefix + "host") && config.has(prefix + "port")) { - std::string port_name = prefix + "port"; std::string listen_host = prefix + "host"; bool is_secure = false; auto stack = std::make_unique(*this); while (true) { - if (!config.has(prefix + "type")) + // if there is no "type" - it's a reference to another protocol and this is just another endpoint + if (config.has(prefix + "type")) { - // misconfigured - lack of "type" - stack.reset(); - break; - } - - std::string type = config.getString(prefix + "type"); - if (type == "tls") - { - if (is_secure) + std::string type = config.getString(prefix + "type"); + if (type == "tls") { - // misconfigured - only one tls layer is allowed + if (is_secure) + { + // misconfigured - only one tls layer is allowed + stack.reset(); + break; + } + is_secure = true; + } + + TCPServerConnectionFactory::Ptr factory = createFactory(type); + if (!factory) + { + // misconfigured - protocol type doesn't exist stack.reset(); break; } - is_secure = true; + + stack->append(factory); + + if (!config.has(prefix + "impl")) + break; } - TCPServerConnectionFactory::Ptr factory = createFactory(type); - if (!factory) + prefix = "protocols." + config.getString(prefix + "impl") + "."; + + if (!pset.insert(prefix).second) { - // misconfigured - protocol "type" doesn't exist + // misconfigured - loop is detected stack.reset(); break; } - - stack->append(factory); - - if (!config.has(prefix + "impl")) - { - stack->append(createFactory("tcp")); - break; - } - prefix = "protocols." + config.getString(prefix + "impl"); } if (!stack) diff --git a/src/Server/TCPHandler.cpp b/src/Server/TCPHandler.cpp index 1fc88168b35..44b6cfdd628 100644 --- a/src/Server/TCPHandler.cpp +++ b/src/Server/TCPHandler.cpp @@ -109,6 +109,16 @@ TCPHandler::TCPHandler(IServer & server_, TCPServer & tcp_server_, const Poco::N { } +TCPHandler::TCPHandler(IServer & server_, TCPServer & tcp_server_, const Poco::Net::StreamSocket & socket_, TCPProtocolStackData & stack_data, std::string server_display_name_) +: Poco::Net::TCPServerConnection(socket_) + , server(server_) + , tcp_server(tcp_server_) + , log(&Poco::Logger::get("TCPHandler")) + , forwarded_for(stack_data.forwarded_for) + , server_display_name(std::move(server_display_name_)) +{ +} + TCPHandler::~TCPHandler() { try diff --git a/src/Server/TCPHandler.h b/src/Server/TCPHandler.h index ea5fb2f9fe0..13c3c5f70c1 100644 --- a/src/Server/TCPHandler.h +++ b/src/Server/TCPHandler.h @@ -22,6 +22,7 @@ #include #include "IServer.h" +#include "Server/TCPProtocolStackData.h" #include "base/types.h" @@ -137,6 +138,7 @@ public: * Proxy-forwarded (original client) IP address is used for quota accounting if quota is keyed by forwarded IP. */ TCPHandler(IServer & server_, TCPServer & tcp_server_, const Poco::Net::StreamSocket & socket_, bool parse_proxy_protocol_, std::string server_display_name_); + TCPHandler(IServer & server_, TCPServer & tcp_server_, const Poco::Net::StreamSocket & socket_, TCPProtocolStackData & stack_data, std::string server_display_name_); ~TCPHandler() override; void run() override; diff --git a/src/Server/TCPHandlerFactory.h b/src/Server/TCPHandlerFactory.h index 354c886f4c0..fde04c6e0ab 100644 --- a/src/Server/TCPHandlerFactory.h +++ b/src/Server/TCPHandlerFactory.h @@ -3,6 +3,7 @@ #include #include #include +#include "Server/TCPProtocolStackData.h" #include #include #include @@ -53,6 +54,21 @@ public: return new DummyTCPHandler(socket); } } + + Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer & tcp_server, TCPProtocolStackData & stack_data) override + { + try + { + LOG_TRACE(log, "TCP Request. Address: {}", socket.peerAddress().toString()); + + return new TCPHandler(server, tcp_server, socket, stack_data, server_display_name); + } + catch (const Poco::Net::NetException &) + { + LOG_TRACE(log, "TCP Request. Client is not connected (most likely RST packet was sent)."); + return new DummyTCPHandler(socket); + } + } }; } diff --git a/src/Server/TCPProtocolStack.h b/src/Server/TCPProtocolStack.h index c72dfd98f53..21687898d45 100644 --- a/src/Server/TCPProtocolStack.h +++ b/src/Server/TCPProtocolStack.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include @@ -14,13 +15,24 @@ #include #include "Poco/Net/SSLManager.h" +#include +#include "Interpreters/Context.h" +#include "Server/TCPProtocolStackData.h" #include "base/types.h" namespace DB { +namespace ErrorCodes +{ + extern const int NETWORK_ERROR; + extern const int SOCKET_TIMEOUT; + extern const int CANNOT_READ_FROM_SOCKET; + extern const int CANNOT_PARSE_INPUT_ASSERTION_FAILED; +} + class TCPConnectionAccessor : public Poco::Net::TCPServerConnection { public: @@ -43,12 +55,16 @@ public: void run() override { + TCPProtocolStackData stack_data; + stack_data.socket = socket(); for (auto & factory : stack) { - std::unique_ptr connection(factory->createConnection(socket(), tcp_server)); + std::unique_ptr connection(factory->createConnection(socket(), tcp_server, stack_data)); connection->run(); - if (auto * accessor = dynamic_cast(connection.get()); accessor) - socket() = accessor->socket(); + if (stack_data.socket != socket()) + socket() = stack_data.socket; +// if (auto * accessor = dynamic_cast(connection.get()); accessor) + // socket() = accessor->socket(); } } }; @@ -99,17 +115,23 @@ public: -class TLSHandler : public TCPConnectionAccessor +class TLSHandler : public Poco::Net::TCPServerConnection //TCPConnectionAccessor { using StreamSocket = Poco::Net::StreamSocket; using SecureStreamSocket = Poco::Net::SecureStreamSocket; public: - explicit TLSHandler(const StreamSocket & socket) : TCPConnectionAccessor(socket) {} + explicit TLSHandler(const StreamSocket & socket, TCPProtocolStackData & stack_data_) + : Poco::Net::TCPServerConnection(socket) //TCPConnectionAccessor(socket) + , stack_data(stack_data_) + {} void run() override { socket() = SecureStreamSocket::attach(socket(), Poco::Net::SSLManager::instance().defaultServerContext()); + stack_data.socket = socket(); } +private: + TCPProtocolStackData & stack_data; }; @@ -134,12 +156,18 @@ public: server_display_name = server.config().getString("display_name", getFQDNOrHostName()); } - Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer &/* tcp_server*/) override + Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer & tcp_server) override + { + TCPProtocolStackData stack_data; + return createConnection(socket, tcp_server, stack_data); + } + + Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer &/* tcp_server*/, TCPProtocolStackData & stack_data) override { try { LOG_TRACE(log, "TCP Request. Address: {}", socket.peerAddress().toString()); - return new TLSHandler(socket); + return new TLSHandler(socket, stack_data); } catch (const Poco::Net::NetException &) { @@ -150,4 +178,164 @@ public: }; +class ProxyV1Handler : public Poco::Net::TCPServerConnection +{ + using StreamSocket = Poco::Net::StreamSocket; +public: + explicit ProxyV1Handler(const StreamSocket & socket, IServer & server_, TCPProtocolStackData & stack_data_) + : Poco::Net::TCPServerConnection(socket), server(server_), stack_data(stack_data_) {} + + void run() override + { + const auto & settings = server.context()->getSettingsRef(); + socket().setReceiveTimeout(settings.receive_timeout); + + std::string word; + bool eol; + + // Read PROXYv1 protocol header + // http://www.haproxy.org/download/1.8/doc/proxy-protocol.txt + + // read "PROXY" + if (!readWord(5, word, eol) || word != "PROXY" || eol) + throw ParsingException("PROXY protocol violation", ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED); + + // read "TCP4" or "TCP6" or "UNKNOWN" + if (!readWord(7, word, eol)) + throw ParsingException("PROXY protocol violation", ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED); + + if (word != "TCP4" && word != "TCP6" && word != "UNKNOWN") + throw ParsingException("PROXY protocol violation", ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED); + + if (word == "UNKNOWN" && eol) + return; + + if (eol) + throw ParsingException("PROXY protocol violation", ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED); + + // read address + if (!readWord(39, word, eol) || eol) + throw ParsingException("PROXY protocol violation", ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED); + + stack_data.forwarded_for = std::move(word); + + // read address + if (!readWord(39, word, eol) || eol) + throw ParsingException("PROXY protocol violation", ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED); + + // read port + if (!readWord(5, word, eol) || eol) + throw ParsingException("PROXY protocol violation", ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED); + + // read port and "\r\n" + if (!readWord(5, word, eol) || !eol) + throw ParsingException("PROXY protocol violation", ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED); + } + +protected: + bool readWord(int max_len, std::string & word, bool & eol) + { + word.clear(); + eol = false; + + char ch = 0; + int n = 0; + bool is_cr = false; + try + { + for (++max_len; max_len > 0 || is_cr; --max_len) + { + n = socket().receiveBytes(&ch, 1); + if (n == 0) + { + socket().shutdown(); + return false; + } + if (n < 0) + break; + + if (is_cr) + return ch == 0x0A; + + if (ch == 0x0D) + { + is_cr = true; + eol = true; + continue; + } + + if (ch == ' ') + return true; + + word.push_back(ch); + } + } + catch (const Poco::Net::NetException & e) + { + throw NetException(e.displayText() + ", while reading from socket (" + socket().peerAddress().toString() + ")", ErrorCodes::NETWORK_ERROR); + } + catch (const Poco::TimeoutException &) + { + throw NetException(fmt::format("Timeout exceeded while reading from socket ({}, {} ms)", + socket().peerAddress().toString(), + socket().getReceiveTimeout().totalMilliseconds()), ErrorCodes::SOCKET_TIMEOUT); + } + catch (const Poco::IOException & e) + { + throw NetException(e.displayText() + ", while reading from socket (" + socket().peerAddress().toString() + ")", ErrorCodes::NETWORK_ERROR); + } + + if (n < 0) + throw NetException("Cannot read from socket (" + socket().peerAddress().toString() + ")", ErrorCodes::CANNOT_READ_FROM_SOCKET); + + return false; + } + +private: + IServer & server; + TCPProtocolStackData & stack_data; +}; + +class ProxyV1HandlerFactory : public TCPServerConnectionFactory +{ +private: + IServer & server; + Poco::Logger * log; + std::string server_display_name; + + class DummyTCPHandler : public Poco::Net::TCPServerConnection + { + public: + using Poco::Net::TCPServerConnection::TCPServerConnection; + void run() override {} + }; + +public: + explicit ProxyV1HandlerFactory(IServer & server_) + : server(server_), log(&Poco::Logger::get("ProxyV1HandlerFactory")) + { + server_display_name = server.config().getString("display_name", getFQDNOrHostName()); + } + + Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer & tcp_server) override + { + TCPProtocolStackData stack_data; + return createConnection(socket, tcp_server, stack_data); + } + + Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer &/* tcp_server*/, TCPProtocolStackData & stack_data) override + { + try + { + LOG_TRACE(log, "TCP Request. Address: {}", socket.peerAddress().toString()); + return new ProxyV1Handler(socket, server, stack_data); + } + catch (const Poco::Net::NetException &) + { + LOG_TRACE(log, "TCP Request. Client is not connected (most likely RST packet was sent)."); + return new DummyTCPHandler(socket); + } + } +}; + } diff --git a/src/Server/TCPProtocolStackData.h b/src/Server/TCPProtocolStackData.h new file mode 100644 index 00000000000..bc90de8c678 --- /dev/null +++ b/src/Server/TCPProtocolStackData.h @@ -0,0 +1,15 @@ +#pragma once + +#include +#include + +namespace DB +{ + +struct TCPProtocolStackData +{ + Poco::Net::StreamSocket socket; + std::string forwarded_for; +}; + +} diff --git a/src/Server/TCPServerConnectionFactory.h b/src/Server/TCPServerConnectionFactory.h index 613f98352bd..ab9b0848ed7 100644 --- a/src/Server/TCPServerConnectionFactory.h +++ b/src/Server/TCPServerConnectionFactory.h @@ -1,6 +1,7 @@ #pragma once #include +#include "Server/TCPProtocolStackData.h" namespace Poco { @@ -23,5 +24,9 @@ public: /// Same as Poco::Net::TCPServerConnectionFactory except we can pass the TCPServer virtual Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer & tcp_server) = 0; + virtual Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer & tcp_server, TCPProtocolStackData &/* stack_data */) + { + return createConnection(socket, tcp_server); + } }; }