add PROXYv1 handler, add stack exchange data block, tuneup protocols config

This commit is contained in:
Yakov Olkhovskiy 2022-09-10 20:21:37 +00:00
parent 8a7fe2888a
commit 772bf050da
7 changed files with 274 additions and 32 deletions

View File

@ -71,6 +71,7 @@
#include <Dictionaries/registerDictionaries.h>
#include <Disks/registerDisks.h>
#include <Common/Config/ConfigReloader.h>
#include <Server/HTTP/HTTPServerConnectionFactory.h>
#include <Server/HTTPHandlerFactory.h>
#include "MetricsTransmitter.h"
#include <Common/StatusFile.h>
@ -88,6 +89,7 @@
#include <Server/HTTP/HTTPServer.h>
#include <Interpreters/AsynchronousInsertQueue.h>
#include <filesystem>
#include <unordered_set>
#include <Server/TCPProtocolStack.h>
@ -1882,6 +1884,8 @@ void Server::createServers(
return TCPServerConnectionFactory::Ptr(new TCPHandlerFactory(*this, false, false));
if (type == "tls")
return TCPServerConnectionFactory::Ptr(new TLSHandlerFactory(*this));
if (type == "proxy1")
return TCPServerConnectionFactory::Ptr(new ProxyV1HandlerFactory(*this));
if (type == "mysql")
return TCPServerConnectionFactory::Ptr(new MySQLHandlerFactory(*this));
if (type == "postgres")
@ -1906,23 +1910,19 @@ void Server::createServers(
for (const auto & protocol : protocols)
{
std::string prefix = protocol + ".";
std::unordered_set<std::string> pset {prefix};
if (config.has(prefix + "host") && config.has(prefix + "port"))
{
std::string port_name = prefix + "port";
std::string listen_host = prefix + "host";
bool is_secure = false;
auto stack = std::make_unique<TCPProtocolStackFactory>(*this);
while (true)
{
if (!config.has(prefix + "type"))
// if there is no "type" - it's a reference to another protocol and this is just another endpoint
if (config.has(prefix + "type"))
{
// misconfigured - lack of "type"
stack.reset();
break;
}
std::string type = config.getString(prefix + "type");
if (type == "tls")
{
@ -1938,7 +1938,7 @@ void Server::createServers(
TCPServerConnectionFactory::Ptr factory = createFactory(type);
if (!factory)
{
// misconfigured - protocol "type" doesn't exist
// misconfigured - protocol type doesn't exist
stack.reset();
break;
}
@ -1946,11 +1946,17 @@ void Server::createServers(
stack->append(factory);
if (!config.has(prefix + "impl"))
{
stack->append(createFactory("tcp"));
break;
}
prefix = "protocols." + config.getString(prefix + "impl");
prefix = "protocols." + config.getString(prefix + "impl") + ".";
if (!pset.insert(prefix).second)
{
// misconfigured - loop is detected
stack.reset();
break;
}
}
if (!stack)

View File

@ -109,6 +109,16 @@ TCPHandler::TCPHandler(IServer & server_, TCPServer & tcp_server_, const Poco::N
{
}
TCPHandler::TCPHandler(IServer & server_, TCPServer & tcp_server_, const Poco::Net::StreamSocket & socket_, TCPProtocolStackData & stack_data, std::string server_display_name_)
: Poco::Net::TCPServerConnection(socket_)
, server(server_)
, tcp_server(tcp_server_)
, log(&Poco::Logger::get("TCPHandler"))
, forwarded_for(stack_data.forwarded_for)
, server_display_name(std::move(server_display_name_))
{
}
TCPHandler::~TCPHandler()
{
try

View File

@ -22,6 +22,7 @@
#include <Storages/MergeTree/ParallelReplicasReadingCoordinator.h>
#include "IServer.h"
#include "Server/TCPProtocolStackData.h"
#include "base/types.h"
@ -137,6 +138,7 @@ public:
* Proxy-forwarded (original client) IP address is used for quota accounting if quota is keyed by forwarded IP.
*/
TCPHandler(IServer & server_, TCPServer & tcp_server_, const Poco::Net::StreamSocket & socket_, bool parse_proxy_protocol_, std::string server_display_name_);
TCPHandler(IServer & server_, TCPServer & tcp_server_, const Poco::Net::StreamSocket & socket_, TCPProtocolStackData & stack_data, std::string server_display_name_);
~TCPHandler() override;
void run() override;

View File

@ -3,6 +3,7 @@
#include <Poco/Net/NetException.h>
#include <Poco/Util/LayeredConfiguration.h>
#include <Common/logger_useful.h>
#include "Server/TCPProtocolStackData.h"
#include <Server/IServer.h>
#include <Server/TCPHandler.h>
#include <Server/TCPServerConnectionFactory.h>
@ -53,6 +54,21 @@ public:
return new DummyTCPHandler(socket);
}
}
Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer & tcp_server, TCPProtocolStackData & stack_data) override
{
try
{
LOG_TRACE(log, "TCP Request. Address: {}", socket.peerAddress().toString());
return new TCPHandler(server, tcp_server, socket, stack_data, server_display_name);
}
catch (const Poco::Net::NetException &)
{
LOG_TRACE(log, "TCP Request. Client is not connected (most likely RST packet was sent).");
return new DummyTCPHandler(socket);
}
}
};
}

View File

@ -1,5 +1,6 @@
#pragma once
#include <cstring>
#include <memory>
#include <list>
@ -14,13 +15,24 @@
#include <Poco/Net/SecureStreamSocket.h>
#include "Poco/Net/SSLManager.h"
#include <Common/NetException.h>
#include "Interpreters/Context.h"
#include "Server/TCPProtocolStackData.h"
#include "base/types.h"
namespace DB
{
namespace ErrorCodes
{
extern const int NETWORK_ERROR;
extern const int SOCKET_TIMEOUT;
extern const int CANNOT_READ_FROM_SOCKET;
extern const int CANNOT_PARSE_INPUT_ASSERTION_FAILED;
}
class TCPConnectionAccessor : public Poco::Net::TCPServerConnection
{
public:
@ -43,12 +55,16 @@ public:
void run() override
{
TCPProtocolStackData stack_data;
stack_data.socket = socket();
for (auto & factory : stack)
{
std::unique_ptr<TCPServerConnection> connection(factory->createConnection(socket(), tcp_server));
std::unique_ptr<TCPServerConnection> connection(factory->createConnection(socket(), tcp_server, stack_data));
connection->run();
if (auto * accessor = dynamic_cast<TCPConnectionAccessor*>(connection.get()); accessor)
socket() = accessor->socket();
if (stack_data.socket != socket())
socket() = stack_data.socket;
// if (auto * accessor = dynamic_cast<TCPConnectionAccessor*>(connection.get()); accessor)
// socket() = accessor->socket();
}
}
};
@ -99,17 +115,23 @@ public:
class TLSHandler : public TCPConnectionAccessor
class TLSHandler : public Poco::Net::TCPServerConnection //TCPConnectionAccessor
{
using StreamSocket = Poco::Net::StreamSocket;
using SecureStreamSocket = Poco::Net::SecureStreamSocket;
public:
explicit TLSHandler(const StreamSocket & socket) : TCPConnectionAccessor(socket) {}
explicit TLSHandler(const StreamSocket & socket, TCPProtocolStackData & stack_data_)
: Poco::Net::TCPServerConnection(socket) //TCPConnectionAccessor(socket)
, stack_data(stack_data_)
{}
void run() override
{
socket() = SecureStreamSocket::attach(socket(), Poco::Net::SSLManager::instance().defaultServerContext());
stack_data.socket = socket();
}
private:
TCPProtocolStackData & stack_data;
};
@ -134,12 +156,18 @@ public:
server_display_name = server.config().getString("display_name", getFQDNOrHostName());
}
Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer &/* tcp_server*/) override
Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer & tcp_server) override
{
TCPProtocolStackData stack_data;
return createConnection(socket, tcp_server, stack_data);
}
Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer &/* tcp_server*/, TCPProtocolStackData & stack_data) override
{
try
{
LOG_TRACE(log, "TCP Request. Address: {}", socket.peerAddress().toString());
return new TLSHandler(socket);
return new TLSHandler(socket, stack_data);
}
catch (const Poco::Net::NetException &)
{
@ -150,4 +178,164 @@ public:
};
class ProxyV1Handler : public Poco::Net::TCPServerConnection
{
using StreamSocket = Poco::Net::StreamSocket;
public:
explicit ProxyV1Handler(const StreamSocket & socket, IServer & server_, TCPProtocolStackData & stack_data_)
: Poco::Net::TCPServerConnection(socket), server(server_), stack_data(stack_data_) {}
void run() override
{
const auto & settings = server.context()->getSettingsRef();
socket().setReceiveTimeout(settings.receive_timeout);
std::string word;
bool eol;
// Read PROXYv1 protocol header
// http://www.haproxy.org/download/1.8/doc/proxy-protocol.txt
// read "PROXY"
if (!readWord(5, word, eol) || word != "PROXY" || eol)
throw ParsingException("PROXY protocol violation", ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED);
// read "TCP4" or "TCP6" or "UNKNOWN"
if (!readWord(7, word, eol))
throw ParsingException("PROXY protocol violation", ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED);
if (word != "TCP4" && word != "TCP6" && word != "UNKNOWN")
throw ParsingException("PROXY protocol violation", ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED);
if (word == "UNKNOWN" && eol)
return;
if (eol)
throw ParsingException("PROXY protocol violation", ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED);
// read address
if (!readWord(39, word, eol) || eol)
throw ParsingException("PROXY protocol violation", ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED);
stack_data.forwarded_for = std::move(word);
// read address
if (!readWord(39, word, eol) || eol)
throw ParsingException("PROXY protocol violation", ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED);
// read port
if (!readWord(5, word, eol) || eol)
throw ParsingException("PROXY protocol violation", ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED);
// read port and "\r\n"
if (!readWord(5, word, eol) || !eol)
throw ParsingException("PROXY protocol violation", ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED);
}
protected:
bool readWord(int max_len, std::string & word, bool & eol)
{
word.clear();
eol = false;
char ch = 0;
int n = 0;
bool is_cr = false;
try
{
for (++max_len; max_len > 0 || is_cr; --max_len)
{
n = socket().receiveBytes(&ch, 1);
if (n == 0)
{
socket().shutdown();
return false;
}
if (n < 0)
break;
if (is_cr)
return ch == 0x0A;
if (ch == 0x0D)
{
is_cr = true;
eol = true;
continue;
}
if (ch == ' ')
return true;
word.push_back(ch);
}
}
catch (const Poco::Net::NetException & e)
{
throw NetException(e.displayText() + ", while reading from socket (" + socket().peerAddress().toString() + ")", ErrorCodes::NETWORK_ERROR);
}
catch (const Poco::TimeoutException &)
{
throw NetException(fmt::format("Timeout exceeded while reading from socket ({}, {} ms)",
socket().peerAddress().toString(),
socket().getReceiveTimeout().totalMilliseconds()), ErrorCodes::SOCKET_TIMEOUT);
}
catch (const Poco::IOException & e)
{
throw NetException(e.displayText() + ", while reading from socket (" + socket().peerAddress().toString() + ")", ErrorCodes::NETWORK_ERROR);
}
if (n < 0)
throw NetException("Cannot read from socket (" + socket().peerAddress().toString() + ")", ErrorCodes::CANNOT_READ_FROM_SOCKET);
return false;
}
private:
IServer & server;
TCPProtocolStackData & stack_data;
};
class ProxyV1HandlerFactory : public TCPServerConnectionFactory
{
private:
IServer & server;
Poco::Logger * log;
std::string server_display_name;
class DummyTCPHandler : public Poco::Net::TCPServerConnection
{
public:
using Poco::Net::TCPServerConnection::TCPServerConnection;
void run() override {}
};
public:
explicit ProxyV1HandlerFactory(IServer & server_)
: server(server_), log(&Poco::Logger::get("ProxyV1HandlerFactory"))
{
server_display_name = server.config().getString("display_name", getFQDNOrHostName());
}
Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer & tcp_server) override
{
TCPProtocolStackData stack_data;
return createConnection(socket, tcp_server, stack_data);
}
Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer &/* tcp_server*/, TCPProtocolStackData & stack_data) override
{
try
{
LOG_TRACE(log, "TCP Request. Address: {}", socket.peerAddress().toString());
return new ProxyV1Handler(socket, server, stack_data);
}
catch (const Poco::Net::NetException &)
{
LOG_TRACE(log, "TCP Request. Client is not connected (most likely RST packet was sent).");
return new DummyTCPHandler(socket);
}
}
};
}

View File

@ -0,0 +1,15 @@
#pragma once
#include <string>
#include <Poco/Net/StreamSocket.h>
namespace DB
{
struct TCPProtocolStackData
{
Poco::Net::StreamSocket socket;
std::string forwarded_for;
};
}

View File

@ -1,6 +1,7 @@
#pragma once
#include <Poco/SharedPtr.h>
#include "Server/TCPProtocolStackData.h"
namespace Poco
{
@ -23,5 +24,9 @@ public:
/// Same as Poco::Net::TCPServerConnectionFactory except we can pass the TCPServer
virtual Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer & tcp_server) = 0;
virtual Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer & tcp_server, TCPProtocolStackData &/* stack_data */)
{
return createConnection(socket, tcp_server);
}
};
}