mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-26 09:32:01 +00:00
Merge pull request #5419 from yurriy/mysql
Improvements of MySQL Wire Protocol
This commit is contained in:
commit
ead911efc2
@ -9,6 +9,7 @@ set(CLICKHOUSE_SERVER_SOURCES
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/Server.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/Server.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/TCPHandler.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/TCPHandler.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/MySQLHandler.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/MySQLHandler.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/MySQLHandlerFactory.cpp
|
||||||
)
|
)
|
||||||
|
|
||||||
set(CLICKHOUSE_SERVER_LINK PRIVATE clickhouse_dictionaries clickhouse_common_io PUBLIC daemon PRIVATE clickhouse_storages_system clickhouse_functions clickhouse_aggregate_functions clickhouse_table_functions ${Poco_Net_LIBRARY})
|
set(CLICKHOUSE_SERVER_LINK PRIVATE clickhouse_dictionaries clickhouse_common_io PUBLIC daemon PRIVATE clickhouse_storages_system clickhouse_functions clickhouse_aggregate_functions clickhouse_table_functions ${Poco_Net_LIBRARY})
|
||||||
|
@ -9,12 +9,14 @@
|
|||||||
#include <Columns/ColumnVector.h>
|
#include <Columns/ColumnVector.h>
|
||||||
#include <Common/config_version.h>
|
#include <Common/config_version.h>
|
||||||
#include <Common/NetException.h>
|
#include <Common/NetException.h>
|
||||||
|
#include <Common/OpenSSLHelpers.h>
|
||||||
#include <Poco/Crypto/RSAKey.h>
|
#include <Poco/Crypto/RSAKey.h>
|
||||||
#include <Poco/Crypto/CipherFactory.h>
|
#include <Poco/Crypto/CipherFactory.h>
|
||||||
#include <Poco/Net/SecureStreamSocket.h>
|
#include <Poco/Net/SecureStreamSocket.h>
|
||||||
#include <Poco/Net/SSLManager.h>
|
#include <Poco/Net/SSLManager.h>
|
||||||
#include "MySQLHandler.h"
|
#include "MySQLHandler.h"
|
||||||
#include <limits>
|
#include <limits>
|
||||||
|
#include <ext/scope_guard.h>
|
||||||
|
|
||||||
|
|
||||||
namespace DB
|
namespace DB
|
||||||
@ -25,17 +27,28 @@ using Poco::Net::SSLManager;
|
|||||||
|
|
||||||
namespace ErrorCodes
|
namespace ErrorCodes
|
||||||
{
|
{
|
||||||
extern const int MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES;
|
extern const int MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES;
|
||||||
extern const int UNKNOWN_EXCEPTION;
|
extern const int OPENSSL_ERROR;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t MySQLHandler::last_connection_id = 0;
|
MySQLHandler::MySQLHandler(IServer & server_, const Poco::Net::StreamSocket & socket_, RSA & public_key, RSA & private_key, bool ssl_enabled, size_t connection_id)
|
||||||
|
: Poco::Net::TCPServerConnection(socket_)
|
||||||
|
, server(server_)
|
||||||
|
, log(&Poco::Logger::get("MySQLHandler"))
|
||||||
|
, connection_context(server.context())
|
||||||
|
, connection_id(connection_id)
|
||||||
|
, public_key(public_key)
|
||||||
|
, private_key(private_key)
|
||||||
|
{
|
||||||
|
server_capability_flags = CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH | CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA | CLIENT_CONNECT_WITH_DB | CLIENT_DEPRECATE_EOF;
|
||||||
|
if (ssl_enabled)
|
||||||
|
server_capability_flags |= CLIENT_SSL;
|
||||||
|
}
|
||||||
|
|
||||||
void MySQLHandler::run()
|
void MySQLHandler::run()
|
||||||
{
|
{
|
||||||
connection_context = server.context();
|
connection_context = server.context();
|
||||||
connection_context.setDefaultFormat("MySQL");
|
connection_context.setDefaultFormat("MySQLWire");
|
||||||
|
|
||||||
in = std::make_shared<ReadBufferFromPocoSocket>(socket());
|
in = std::make_shared<ReadBufferFromPocoSocket>(socket());
|
||||||
out = std::make_shared<WriteBufferFromPocoSocket>(socket());
|
out = std::make_shared<WriteBufferFromPocoSocket>(socket());
|
||||||
@ -49,8 +62,7 @@ void MySQLHandler::run()
|
|||||||
* This plugin must do the same to stay consistent with historical behavior if it is set to operate as a default plugin.
|
* This plugin must do the same to stay consistent with historical behavior if it is set to operate as a default plugin.
|
||||||
* https://github.com/mysql/mysql-server/blob/8.0/sql/auth/sql_authentication.cc#L3994
|
* https://github.com/mysql/mysql-server/blob/8.0/sql/auth/sql_authentication.cc#L3994
|
||||||
*/
|
*/
|
||||||
Handshake handshake(connection_id, VERSION_STRING, scramble + '\0');
|
Handshake handshake(server_capability_flags, connection_id, VERSION_STRING, scramble + '\0');
|
||||||
|
|
||||||
packet_sender->sendPacket<Handshake>(handshake, true);
|
packet_sender->sendPacket<Handshake>(handshake, true);
|
||||||
|
|
||||||
LOG_TRACE(log, "Sent handshake");
|
LOG_TRACE(log, "Sent handshake");
|
||||||
@ -78,15 +90,11 @@ void MySQLHandler::run()
|
|||||||
<< "\nauth_plugin_name: "
|
<< "\nauth_plugin_name: "
|
||||||
<< handshake_response.auth_plugin_name);
|
<< handshake_response.auth_plugin_name);
|
||||||
|
|
||||||
capabilities = handshake_response.capability_flags;
|
client_capability_flags = handshake_response.capability_flags;
|
||||||
if (!(capabilities & CLIENT_PROTOCOL_41))
|
if (!(client_capability_flags & CLIENT_PROTOCOL_41))
|
||||||
{
|
|
||||||
throw Exception("Required capability: CLIENT_PROTOCOL_41.", ErrorCodes::MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES);
|
throw Exception("Required capability: CLIENT_PROTOCOL_41.", ErrorCodes::MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES);
|
||||||
}
|
if (!(client_capability_flags & CLIENT_PLUGIN_AUTH))
|
||||||
if (!(capabilities & CLIENT_PLUGIN_AUTH))
|
|
||||||
{
|
|
||||||
throw Exception("Required capability: CLIENT_PLUGIN_AUTH.", ErrorCodes::MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES);
|
throw Exception("Required capability: CLIENT_PLUGIN_AUTH.", ErrorCodes::MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES);
|
||||||
}
|
|
||||||
|
|
||||||
authenticate(handshake_response, scramble);
|
authenticate(handshake_response, scramble);
|
||||||
OK_Packet ok_packet(0, handshake_response.capability_flags, 0, 0, 0);
|
OK_Packet ok_packet(0, handshake_response.capability_flags, 0, 0, 0);
|
||||||
@ -165,7 +173,7 @@ MySQLProtocol::HandshakeResponse MySQLHandler::finishHandshake()
|
|||||||
};
|
};
|
||||||
read_bytes(3); /// We can find out whether it is SSLRequest of HandshakeResponse by first 3 bytes.
|
read_bytes(3); /// We can find out whether it is SSLRequest of HandshakeResponse by first 3 bytes.
|
||||||
|
|
||||||
size_t payload_size = *reinterpret_cast<uint32_t *>(buf) & 0xFFFFFFu;
|
size_t payload_size = unalignedLoad<uint32_t>(buf) & 0xFFFFFFu;
|
||||||
LOG_TRACE(log, "payload size: " << payload_size);
|
LOG_TRACE(log, "payload size: " << payload_size);
|
||||||
|
|
||||||
if (payload_size == SSL_REQUEST_PAYLOAD_SIZE)
|
if (payload_size == SSL_REQUEST_PAYLOAD_SIZE)
|
||||||
@ -226,31 +234,19 @@ void MySQLHandler::authenticate(const HandshakeResponse & handshake_response, co
|
|||||||
LOG_TRACE(log, "Authentication method match.");
|
LOG_TRACE(log, "Authentication method match.");
|
||||||
}
|
}
|
||||||
|
|
||||||
auto getOpenSSLError = []() -> String
|
|
||||||
{
|
|
||||||
BIO * mem = BIO_new(BIO_s_mem());
|
|
||||||
ERR_print_errors(mem);
|
|
||||||
char * buf = nullptr;
|
|
||||||
long size = BIO_get_mem_data(mem, &buf);
|
|
||||||
String errors_str(buf, size);
|
|
||||||
BIO_free(mem);
|
|
||||||
return errors_str;
|
|
||||||
};
|
|
||||||
|
|
||||||
if (auth_response == "\1")
|
if (auth_response == "\1")
|
||||||
{
|
{
|
||||||
LOG_TRACE(log, "Client requests public key.");
|
LOG_TRACE(log, "Client requests public key.");
|
||||||
|
|
||||||
BIO * mem = BIO_new(BIO_s_mem());
|
BIO * mem = BIO_new(BIO_s_mem());
|
||||||
if (PEM_write_bio_RSA_PUBKEY(mem, public_key) != 1)
|
SCOPE_EXIT(BIO_free(mem));
|
||||||
|
if (PEM_write_bio_RSA_PUBKEY(mem, &public_key) != 1)
|
||||||
{
|
{
|
||||||
LOG_TRACE(log, "OpenSSL error:\n" << getOpenSSLError());
|
throw Exception("Failed to write public key to memory. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
|
||||||
throw Exception("Failed to write public key to memory.", ErrorCodes::UNKNOWN_EXCEPTION);
|
|
||||||
}
|
}
|
||||||
char * pem_buf = nullptr;
|
char * pem_buf = nullptr;
|
||||||
long pem_size = BIO_get_mem_data(mem, &pem_buf);
|
long pem_size = BIO_get_mem_data(mem, &pem_buf);
|
||||||
String pem(pem_buf, pem_size);
|
String pem(pem_buf, pem_size);
|
||||||
BIO_free(mem);
|
|
||||||
|
|
||||||
LOG_TRACE(log, "Key: " << pem);
|
LOG_TRACE(log, "Key: " << pem);
|
||||||
|
|
||||||
@ -271,17 +267,16 @@ void MySQLHandler::authenticate(const HandshakeResponse & handshake_response, co
|
|||||||
* an empty packet is a blank password, thus the check for auth_response.empty() has to be made too.
|
* an empty packet is a blank password, thus the check for auth_response.empty() has to be made too.
|
||||||
* https://github.com/mysql/mysql-server/blob/8.0/sql/auth/sql_authentication.cc#L4017
|
* https://github.com/mysql/mysql-server/blob/8.0/sql/auth/sql_authentication.cc#L4017
|
||||||
*/
|
*/
|
||||||
if (!secure_connection && (!auth_response.empty() && auth_response != "\0"))
|
if (!secure_connection && !auth_response.empty() && auth_response != String("\0", 1))
|
||||||
{
|
{
|
||||||
LOG_TRACE(log, "Received nonempty password");
|
LOG_TRACE(log, "Received nonempty password");
|
||||||
auto ciphertext = reinterpret_cast<unsigned char *>(auth_response.data());
|
auto ciphertext = reinterpret_cast<unsigned char *>(auth_response.data());
|
||||||
|
|
||||||
unsigned char plaintext[RSA_size(private_key)];
|
unsigned char plaintext[RSA_size(&private_key)];
|
||||||
int plaintext_size = RSA_private_decrypt(auth_response.size(), ciphertext, plaintext, private_key, RSA_PKCS1_OAEP_PADDING);
|
int plaintext_size = RSA_private_decrypt(auth_response.size(), ciphertext, plaintext, &private_key, RSA_PKCS1_OAEP_PADDING);
|
||||||
if (plaintext_size == -1)
|
if (plaintext_size == -1)
|
||||||
{
|
{
|
||||||
LOG_TRACE(log, "OpenSSL error:\n" << getOpenSSLError());
|
throw Exception("Failed to decrypt auth data. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
|
||||||
throw Exception("Failed to decrypt.", ErrorCodes::UNKNOWN_EXCEPTION);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
password.resize(plaintext_size);
|
password.resize(plaintext_size);
|
||||||
@ -324,7 +319,7 @@ void MySQLHandler::comInitDB(const String & payload)
|
|||||||
String database = payload.substr(1);
|
String database = payload.substr(1);
|
||||||
LOG_DEBUG(log, "Setting current database to " << database);
|
LOG_DEBUG(log, "Setting current database to " << database);
|
||||||
connection_context.setCurrentDatabase(database);
|
connection_context.setCurrentDatabase(database);
|
||||||
packet_sender->sendPacket(OK_Packet(0, capabilities, 0, 0, 1), true);
|
packet_sender->sendPacket(OK_Packet(0, client_capability_flags, 0, 0, 1), true);
|
||||||
}
|
}
|
||||||
|
|
||||||
void MySQLHandler::comFieldList(const String & payload)
|
void MySQLHandler::comFieldList(const String & payload)
|
||||||
@ -340,12 +335,12 @@ void MySQLHandler::comFieldList(const String & payload)
|
|||||||
);
|
);
|
||||||
packet_sender->sendPacket(column_definition);
|
packet_sender->sendPacket(column_definition);
|
||||||
}
|
}
|
||||||
packet_sender->sendPacket(OK_Packet(0xfe, capabilities, 0, 0, 0), true);
|
packet_sender->sendPacket(OK_Packet(0xfe, client_capability_flags, 0, 0, 0), true);
|
||||||
}
|
}
|
||||||
|
|
||||||
void MySQLHandler::comPing()
|
void MySQLHandler::comPing()
|
||||||
{
|
{
|
||||||
packet_sender->sendPacket(OK_Packet(0x0, capabilities, 0, 0, 0), true);
|
packet_sender->sendPacket(OK_Packet(0x0, client_capability_flags, 0, 0, 0), true);
|
||||||
}
|
}
|
||||||
|
|
||||||
void MySQLHandler::comQuery(const String & payload)
|
void MySQLHandler::comQuery(const String & payload)
|
||||||
@ -357,7 +352,7 @@ void MySQLHandler::comQuery(const String & payload)
|
|||||||
ReadBufferFromMemory query(payload.data() + 1, payload.size() - 1);
|
ReadBufferFromMemory query(payload.data() + 1, payload.size() - 1);
|
||||||
executeQuery(query, *out, true, connection_context, set_content_type, nullptr);
|
executeQuery(query, *out, true, connection_context, set_content_type, nullptr);
|
||||||
if (!with_output)
|
if (!with_output)
|
||||||
packet_sender->sendPacket(OK_Packet(0x00, capabilities, 0, 0, 0), true);
|
packet_sender->sendPacket(OK_Packet(0x00, client_capability_flags, 0, 0, 0), true);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -4,7 +4,7 @@
|
|||||||
#include <Poco/Net/SecureStreamSocket.h>
|
#include <Poco/Net/SecureStreamSocket.h>
|
||||||
#include <Common/getFQDNOrHostName.h>
|
#include <Common/getFQDNOrHostName.h>
|
||||||
#include <Core/MySQLProtocol.h>
|
#include <Core/MySQLProtocol.h>
|
||||||
#include <openssl/evp.h>
|
#include <openssl/rsa.h>
|
||||||
#include "IServer.h"
|
#include "IServer.h"
|
||||||
|
|
||||||
|
|
||||||
@ -15,21 +15,7 @@ namespace DB
|
|||||||
class MySQLHandler : public Poco::Net::TCPServerConnection
|
class MySQLHandler : public Poco::Net::TCPServerConnection
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
MySQLHandler(
|
MySQLHandler(IServer & server_, const Poco::Net::StreamSocket & socket_, RSA & public_key, RSA & private_key, bool ssl_enabled, size_t connection_id);
|
||||||
IServer & server_,
|
|
||||||
const Poco::Net::StreamSocket & socket_,
|
|
||||||
RSA * public_key,
|
|
||||||
RSA * private_key)
|
|
||||||
: Poco::Net::TCPServerConnection(socket_)
|
|
||||||
, server(server_)
|
|
||||||
, log(&Poco::Logger::get("MySQLHandler"))
|
|
||||||
, connection_context(server.context())
|
|
||||||
, connection_id(last_connection_id++)
|
|
||||||
, public_key(public_key)
|
|
||||||
, private_key(private_key)
|
|
||||||
{
|
|
||||||
log->setLevel("information");
|
|
||||||
}
|
|
||||||
|
|
||||||
void run() final;
|
void run() final;
|
||||||
|
|
||||||
@ -55,13 +41,13 @@ private:
|
|||||||
|
|
||||||
std::shared_ptr<MySQLProtocol::PacketSender> packet_sender;
|
std::shared_ptr<MySQLProtocol::PacketSender> packet_sender;
|
||||||
|
|
||||||
uint32_t connection_id = 0;
|
size_t connection_id = 0;
|
||||||
|
|
||||||
uint32_t capabilities;
|
size_t server_capability_flags;
|
||||||
|
size_t client_capability_flags;
|
||||||
|
|
||||||
static uint32_t last_connection_id;
|
RSA & public_key;
|
||||||
|
RSA & private_key;
|
||||||
RSA * public_key, * private_key;
|
|
||||||
|
|
||||||
std::shared_ptr<ReadBuffer> in;
|
std::shared_ptr<ReadBuffer> in;
|
||||||
std::shared_ptr<WriteBuffer> out;
|
std::shared_ptr<WriteBuffer> out;
|
||||||
|
124
dbms/programs/server/MySQLHandlerFactory.cpp
Normal file
124
dbms/programs/server/MySQLHandlerFactory.cpp
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
#include <Common/OpenSSLHelpers.h>
|
||||||
|
#include <Poco/Crypto/X509Certificate.h>
|
||||||
|
#include <Poco/Net/SSLManager.h>
|
||||||
|
#include <Poco/Net/TCPServerConnectionFactory.h>
|
||||||
|
#include <Poco/Util/Application.h>
|
||||||
|
#include <common/logger_useful.h>
|
||||||
|
#include <ext/scope_guard.h>
|
||||||
|
#include "IServer.h"
|
||||||
|
#include "MySQLHandler.h"
|
||||||
|
#include "MySQLHandlerFactory.h"
|
||||||
|
|
||||||
|
namespace DB
|
||||||
|
{
|
||||||
|
|
||||||
|
namespace ErrorCodes
|
||||||
|
{
|
||||||
|
extern const int CANNOT_OPEN_FILE;
|
||||||
|
extern const int NO_ELEMENTS_IN_CONFIG;
|
||||||
|
extern const int OPENSSL_ERROR;
|
||||||
|
extern const int SYSTEM_ERROR;
|
||||||
|
}
|
||||||
|
|
||||||
|
MySQLHandlerFactory::MySQLHandlerFactory(IServer & server_)
|
||||||
|
: server(server_)
|
||||||
|
, log(&Logger::get("MySQLHandlerFactory"))
|
||||||
|
{
|
||||||
|
try
|
||||||
|
{
|
||||||
|
Poco::Net::SSLManager::instance().defaultServerContext();
|
||||||
|
}
|
||||||
|
catch (...)
|
||||||
|
{
|
||||||
|
LOG_INFO(log, "Failed to create SSL context. SSL will be disabled. Error: " << getCurrentExceptionMessage(false));
|
||||||
|
ssl_enabled = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Reading rsa keys for SHA256 authentication plugin.
|
||||||
|
try
|
||||||
|
{
|
||||||
|
readRSAKeys();
|
||||||
|
}
|
||||||
|
catch (...)
|
||||||
|
{
|
||||||
|
LOG_WARNING(log, "Failed to read RSA keys. Error: " << getCurrentExceptionMessage(false));
|
||||||
|
generateRSAKeys();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void MySQLHandlerFactory::readRSAKeys()
|
||||||
|
{
|
||||||
|
const Poco::Util::LayeredConfiguration & config = Poco::Util::Application::instance().config();
|
||||||
|
String certificateFileProperty = "openSSL.server.certificateFile";
|
||||||
|
String privateKeyFileProperty = "openSSL.server.privateKeyFile";
|
||||||
|
|
||||||
|
if (!config.has(certificateFileProperty))
|
||||||
|
throw Exception("Certificate file is not set.", ErrorCodes::NO_ELEMENTS_IN_CONFIG);
|
||||||
|
|
||||||
|
if (!config.has(privateKeyFileProperty))
|
||||||
|
throw Exception("Private key file is not set.", ErrorCodes::NO_ELEMENTS_IN_CONFIG);
|
||||||
|
|
||||||
|
{
|
||||||
|
String certificateFile = config.getString(certificateFileProperty);
|
||||||
|
FILE * fp = fopen(certificateFile.data(), "r");
|
||||||
|
if (fp == nullptr)
|
||||||
|
throw Exception("Cannot open certificate file: " + certificateFile + ".", ErrorCodes::CANNOT_OPEN_FILE);
|
||||||
|
SCOPE_EXIT(fclose(fp));
|
||||||
|
|
||||||
|
X509 * x509 = PEM_read_X509(fp, nullptr, nullptr, nullptr);
|
||||||
|
SCOPE_EXIT(X509_free(x509));
|
||||||
|
if (x509 == nullptr)
|
||||||
|
throw Exception("Failed to read PEM certificate from " + certificateFile + ". Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
|
||||||
|
|
||||||
|
EVP_PKEY * p = X509_get_pubkey(x509);
|
||||||
|
if (p == nullptr)
|
||||||
|
throw Exception("Failed to get RSA key from X509. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
|
||||||
|
SCOPE_EXIT(EVP_PKEY_free(p));
|
||||||
|
|
||||||
|
public_key.reset(EVP_PKEY_get1_RSA(p));
|
||||||
|
if (public_key.get() == nullptr)
|
||||||
|
throw Exception("Failed to get RSA key from ENV_PKEY. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
String privateKeyFile = config.getString(privateKeyFileProperty);
|
||||||
|
|
||||||
|
FILE * fp = fopen(privateKeyFile.data(), "r");
|
||||||
|
if (fp == nullptr)
|
||||||
|
throw Exception ("Cannot open private key file " + privateKeyFile + ".", ErrorCodes::CANNOT_OPEN_FILE);
|
||||||
|
SCOPE_EXIT(fclose(fp));
|
||||||
|
|
||||||
|
private_key.reset(PEM_read_RSAPrivateKey(fp, nullptr, nullptr, nullptr));
|
||||||
|
if (!private_key)
|
||||||
|
throw Exception("Failed to read RSA private key from " + privateKeyFile + ". Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void MySQLHandlerFactory::generateRSAKeys()
|
||||||
|
{
|
||||||
|
LOG_INFO(log, "Generating new RSA key.");
|
||||||
|
public_key.reset(RSA_new());
|
||||||
|
if (!public_key)
|
||||||
|
throw Exception("Failed to allocate RSA key. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
|
||||||
|
|
||||||
|
BIGNUM * e = BN_new();
|
||||||
|
if (!e)
|
||||||
|
throw Exception("Failed to allocate BIGNUM. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
|
||||||
|
SCOPE_EXIT(BN_free(e));
|
||||||
|
|
||||||
|
if (!BN_set_word(e, 65537) || !RSA_generate_key_ex(public_key.get(), 2048, e, nullptr))
|
||||||
|
throw Exception("Failed to generate RSA key. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
|
||||||
|
|
||||||
|
private_key.reset(RSAPrivateKey_dup(public_key.get()));
|
||||||
|
if (!private_key)
|
||||||
|
throw Exception("Failed to copy RSA key. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
|
||||||
|
}
|
||||||
|
|
||||||
|
Poco::Net::TCPServerConnection * MySQLHandlerFactory::createConnection(const Poco::Net::StreamSocket & socket)
|
||||||
|
{
|
||||||
|
size_t connection_id = last_connection_id++;
|
||||||
|
LOG_TRACE(log, "MySQL connection. Id: " << connection_id << ". Address: " << socket.peerAddress().toString());
|
||||||
|
return new MySQLHandler(server, socket, *public_key, *private_key, ssl_enabled, connection_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -1,16 +1,9 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <Poco/Net/TCPServerConnectionFactory.h>
|
#include <Poco/Net/TCPServerConnectionFactory.h>
|
||||||
#include <Poco/Net/SSLManager.h>
|
#include <atomic>
|
||||||
#include <Poco/Crypto/X509Certificate.h>
|
#include <openssl/rsa.h>
|
||||||
#include <common/logger_useful.h>
|
|
||||||
#include "IServer.h"
|
#include "IServer.h"
|
||||||
#include "MySQLHandler.h"
|
|
||||||
|
|
||||||
namespace Poco
|
|
||||||
{
|
|
||||||
class Logger;
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace DB
|
namespace DB
|
||||||
{
|
{
|
||||||
@ -20,94 +13,27 @@ class MySQLHandlerFactory : public Poco::Net::TCPServerConnectionFactory
|
|||||||
private:
|
private:
|
||||||
IServer & server;
|
IServer & server;
|
||||||
Poco::Logger * log;
|
Poco::Logger * log;
|
||||||
RSA * public_key = nullptr, * private_key = nullptr;
|
|
||||||
|
|
||||||
|
struct RSADeleter
|
||||||
|
{
|
||||||
|
void operator()(RSA * ptr) { RSA_free(ptr); }
|
||||||
|
};
|
||||||
|
using RSAPtr = std::unique_ptr<RSA, RSADeleter>;
|
||||||
|
|
||||||
|
RSAPtr public_key;
|
||||||
|
RSAPtr private_key;
|
||||||
|
|
||||||
|
bool ssl_enabled = true;
|
||||||
|
|
||||||
|
std::atomic<size_t> last_connection_id = 0;
|
||||||
public:
|
public:
|
||||||
explicit MySQLHandlerFactory(IServer & server_)
|
explicit MySQLHandlerFactory(IServer & server_);
|
||||||
: server(server_), log(&Logger::get("MySQLHandlerFactory"))
|
|
||||||
{
|
|
||||||
/// Reading rsa keys for SHA256 authentication plugin.
|
|
||||||
const Poco::Util::LayeredConfiguration & config = Poco::Util::Application::instance().config();
|
|
||||||
String certificateFileProperty = "openSSL.server.certificateFile";
|
|
||||||
String privateKeyFileProperty = "openSSL.server.privateKeyFile";
|
|
||||||
|
|
||||||
if (!config.has(certificateFileProperty))
|
void readRSAKeys();
|
||||||
{
|
|
||||||
LOG_INFO(log, "Certificate file is not set.");
|
|
||||||
generateRSAKeys();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (!config.has(privateKeyFileProperty))
|
|
||||||
{
|
|
||||||
LOG_INFO(log, "Private key file is not set.");
|
|
||||||
generateRSAKeys();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
String certificateFile = config.getString(certificateFileProperty);
|
void generateRSAKeys();
|
||||||
FILE * fp = fopen(certificateFile.data(), "r");
|
|
||||||
if (fp == nullptr)
|
|
||||||
{
|
|
||||||
LOG_WARNING(log, "Cannot open certificate file: " << certificateFile << ".");
|
|
||||||
generateRSAKeys();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
X509 * x509 = PEM_read_X509(fp, nullptr, nullptr, nullptr);
|
|
||||||
EVP_PKEY * p = X509_get_pubkey(x509);
|
|
||||||
public_key = EVP_PKEY_get1_RSA(p);
|
|
||||||
X509_free(x509);
|
|
||||||
EVP_PKEY_free(p);
|
|
||||||
fclose(fp);
|
|
||||||
|
|
||||||
String privateKeyFile = config.getString(privateKeyFileProperty);
|
Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket) override;
|
||||||
fp = fopen(privateKeyFile.data(), "r");
|
|
||||||
if (fp == nullptr)
|
|
||||||
{
|
|
||||||
LOG_WARNING(log, "Cannot open private key file " << privateKeyFile << ".");
|
|
||||||
generateRSAKeys();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
private_key = PEM_read_RSAPrivateKey(fp, nullptr, nullptr, nullptr);
|
|
||||||
fclose(fp);
|
|
||||||
}
|
|
||||||
|
|
||||||
void generateRSAKeys()
|
|
||||||
{
|
|
||||||
LOG_INFO(log, "Generating new RSA key.");
|
|
||||||
RSA * rsa = RSA_new();
|
|
||||||
if (rsa == nullptr)
|
|
||||||
{
|
|
||||||
throw Exception("Failed to allocate RSA key.", 1002);
|
|
||||||
}
|
|
||||||
BIGNUM * e = BN_new();
|
|
||||||
if (!e)
|
|
||||||
{
|
|
||||||
RSA_free(rsa);
|
|
||||||
throw Exception("Failed to allocate BIGNUM.", 1002);
|
|
||||||
}
|
|
||||||
if (!BN_set_word(e, 65537) || !RSA_generate_key_ex(rsa, 2048, e, nullptr))
|
|
||||||
{
|
|
||||||
RSA_free(rsa);
|
|
||||||
BN_free(e);
|
|
||||||
throw Exception("Failed to generate RSA key.", 1002);
|
|
||||||
}
|
|
||||||
BN_free(e);
|
|
||||||
|
|
||||||
public_key = rsa;
|
|
||||||
private_key = RSAPrivateKey_dup(rsa);
|
|
||||||
}
|
|
||||||
|
|
||||||
Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket) override
|
|
||||||
{
|
|
||||||
LOG_TRACE(log, "MySQL connection. Address: " << socket.peerAddress().toString());
|
|
||||||
return new MySQLHandler(server, socket, public_key, private_key);
|
|
||||||
}
|
|
||||||
|
|
||||||
~MySQLHandlerFactory() override
|
|
||||||
{
|
|
||||||
RSA_free(public_key);
|
|
||||||
RSA_free(private_key);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -427,6 +427,7 @@ namespace ErrorCodes
|
|||||||
extern const int BAD_TTL_EXPRESSION = 450;
|
extern const int BAD_TTL_EXPRESSION = 450;
|
||||||
extern const int BAD_TTL_FILE = 451;
|
extern const int BAD_TTL_FILE = 451;
|
||||||
extern const int SETTING_CONSTRAINT_VIOLATION = 452;
|
extern const int SETTING_CONSTRAINT_VIOLATION = 452;
|
||||||
|
extern const int OPENSSL_ERROR = 454;
|
||||||
|
|
||||||
extern const int KEEPER_EXCEPTION = 999;
|
extern const int KEEPER_EXCEPTION = 999;
|
||||||
extern const int POCO_EXCEPTION = 1000;
|
extern const int POCO_EXCEPTION = 1000;
|
||||||
|
18
dbms/src/Common/OpenSSLHelpers.cpp
Normal file
18
dbms/src/Common/OpenSSLHelpers.cpp
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
#include "OpenSSLHelpers.h"
|
||||||
|
#include <ext/scope_guard.h>
|
||||||
|
#include <openssl/err.h>
|
||||||
|
|
||||||
|
namespace DB
|
||||||
|
{
|
||||||
|
|
||||||
|
String getOpenSSLErrors()
|
||||||
|
{
|
||||||
|
BIO * mem = BIO_new(BIO_s_mem());
|
||||||
|
SCOPE_EXIT(BIO_free(mem));
|
||||||
|
ERR_print_errors(mem);
|
||||||
|
char * buf = nullptr;
|
||||||
|
long size = BIO_get_mem_data(mem, &buf);
|
||||||
|
return String(buf, size);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
12
dbms/src/Common/OpenSSLHelpers.h
Normal file
12
dbms/src/Common/OpenSSLHelpers.h
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <Core/Types.h>
|
||||||
|
|
||||||
|
|
||||||
|
namespace DB
|
||||||
|
{
|
||||||
|
|
||||||
|
/// Returns concatenation of error strings for all errors that OpenSSL has recorded, emptying the error queue.
|
||||||
|
String getOpenSSLErrors();
|
||||||
|
|
||||||
|
}
|
@ -1,18 +1,18 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <IO/ReadBuffer.h>
|
#include <Core/Types.h>
|
||||||
#include <IO/WriteBuffer.h>
|
|
||||||
#include <IO/copyData.h>
|
#include <IO/copyData.h>
|
||||||
|
#include <IO/ReadBuffer.h>
|
||||||
#include <IO/ReadBufferFromPocoSocket.h>
|
#include <IO/ReadBufferFromPocoSocket.h>
|
||||||
|
#include <IO/WriteBuffer.h>
|
||||||
#include <IO/WriteBufferFromPocoSocket.h>
|
#include <IO/WriteBufferFromPocoSocket.h>
|
||||||
#include <IO/WriteBufferFromString.h>
|
#include <IO/WriteBufferFromString.h>
|
||||||
#include <Core/Types.h>
|
#include <Poco/Logger.h>
|
||||||
#include <Poco/RandomStream.h>
|
|
||||||
#include <Poco/Net/StreamSocket.h>
|
#include <Poco/Net/StreamSocket.h>
|
||||||
|
#include <Poco/RandomStream.h>
|
||||||
|
#include <common/logger_useful.h>
|
||||||
#include <random>
|
#include <random>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <common/logger_useful.h>
|
|
||||||
#include <Poco/Logger.h>
|
|
||||||
|
|
||||||
/// Implementation of MySQL wire protocol
|
/// Implementation of MySQL wire protocol
|
||||||
|
|
||||||
@ -163,7 +163,6 @@ public:
|
|||||||
, out(&out)
|
, out(&out)
|
||||||
, log(&Poco::Logger::get(logger_name))
|
, log(&Poco::Logger::get(logger_name))
|
||||||
{
|
{
|
||||||
log->setLevel("information");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// For writing.
|
/// For writing.
|
||||||
@ -173,7 +172,6 @@ public:
|
|||||||
, out(&out)
|
, out(&out)
|
||||||
, log(&Poco::Logger::get(logger_name))
|
, log(&Poco::Logger::get(logger_name))
|
||||||
{
|
{
|
||||||
log->setLevel("information");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
String receivePacketPayload()
|
String receivePacketPayload()
|
||||||
@ -279,12 +277,11 @@ class Handshake : public WritePacket
|
|||||||
uint32_t status_flags;
|
uint32_t status_flags;
|
||||||
String auth_plugin_data;
|
String auth_plugin_data;
|
||||||
public:
|
public:
|
||||||
explicit Handshake(uint32_t connection_id, String server_version, String auth_plugin_data)
|
explicit Handshake(uint32_t capability_flags, uint32_t connection_id, String server_version, String auth_plugin_data)
|
||||||
: protocol_version(0xa)
|
: protocol_version(0xa)
|
||||||
, server_version(std::move(server_version))
|
, server_version(std::move(server_version))
|
||||||
, connection_id(connection_id)
|
, connection_id(connection_id)
|
||||||
, capability_flags(CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH | CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA
|
, capability_flags(capability_flags)
|
||||||
| CLIENT_CONNECT_WITH_DB | CLIENT_DEPRECATE_EOF | CLIENT_SSL)
|
|
||||||
, character_set(CharacterSet::utf8_general_ci)
|
, character_set(CharacterSet::utf8_general_ci)
|
||||||
, status_flags(0)
|
, status_flags(0)
|
||||||
, auth_plugin_data(auth_plugin_data)
|
, auth_plugin_data(auth_plugin_data)
|
||||||
|
@ -130,7 +130,7 @@ void registerOutputFormatXML(FormatFactory & factory);
|
|||||||
void registerOutputFormatODBCDriver(FormatFactory & factory);
|
void registerOutputFormatODBCDriver(FormatFactory & factory);
|
||||||
void registerOutputFormatODBCDriver2(FormatFactory & factory);
|
void registerOutputFormatODBCDriver2(FormatFactory & factory);
|
||||||
void registerOutputFormatNull(FormatFactory & factory);
|
void registerOutputFormatNull(FormatFactory & factory);
|
||||||
void registerOutputFormatMySQL(FormatFactory & factory);
|
void registerOutputFormatMySQLWire(FormatFactory & factory);
|
||||||
|
|
||||||
/// Input only formats.
|
/// Input only formats.
|
||||||
|
|
||||||
@ -169,7 +169,7 @@ FormatFactory::FormatFactory()
|
|||||||
registerOutputFormatODBCDriver(*this);
|
registerOutputFormatODBCDriver(*this);
|
||||||
registerOutputFormatODBCDriver2(*this);
|
registerOutputFormatODBCDriver2(*this);
|
||||||
registerOutputFormatNull(*this);
|
registerOutputFormatNull(*this);
|
||||||
registerOutputFormatMySQL(*this);
|
registerOutputFormatMySQLWire(*this);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -1,19 +0,0 @@
|
|||||||
#include <DataStreams/MySQLBlockOutputStream.h>
|
|
||||||
|
|
||||||
|
|
||||||
namespace DB
|
|
||||||
{
|
|
||||||
|
|
||||||
void registerOutputFormatMySQL(FormatFactory & factory)
|
|
||||||
{
|
|
||||||
factory.registerOutputFormat("MySQL", [](
|
|
||||||
WriteBuffer & buf,
|
|
||||||
const Block & sample,
|
|
||||||
const Context & context,
|
|
||||||
const FormatSettings &)
|
|
||||||
{
|
|
||||||
return std::make_shared<MySQLBlockOutputStream>(buf, sample, const_cast<Context &>(context));
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
@ -1,4 +1,4 @@
|
|||||||
#include "MySQLBlockOutputStream.h"
|
#include "MySQLWireBlockOutputStream.h"
|
||||||
#include <Core/MySQLProtocol.h>
|
#include <Core/MySQLProtocol.h>
|
||||||
#include <Interpreters/ProcessList.h>
|
#include <Interpreters/ProcessList.h>
|
||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
@ -9,15 +9,15 @@ namespace DB
|
|||||||
|
|
||||||
using namespace MySQLProtocol;
|
using namespace MySQLProtocol;
|
||||||
|
|
||||||
MySQLBlockOutputStream::MySQLBlockOutputStream(WriteBuffer & buf, const Block & header, Context & context)
|
MySQLWireBlockOutputStream::MySQLWireBlockOutputStream(WriteBuffer & buf, const Block & header, Context & context)
|
||||||
: header(header)
|
: header(header)
|
||||||
, context(context)
|
, context(context)
|
||||||
, packet_sender(new PacketSender(buf, context.sequence_id, "MySQLBlockOutputStream"))
|
, packet_sender(new PacketSender(buf, context.sequence_id, "MySQLWireBlockOutputStream"))
|
||||||
{
|
{
|
||||||
packet_sender->max_packet_size = context.max_packet_size;
|
packet_sender->max_packet_size = context.max_packet_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
void MySQLBlockOutputStream::writePrefix()
|
void MySQLWireBlockOutputStream::writePrefix()
|
||||||
{
|
{
|
||||||
if (header.columns() == 0)
|
if (header.columns() == 0)
|
||||||
return;
|
return;
|
||||||
@ -26,8 +26,7 @@ void MySQLBlockOutputStream::writePrefix()
|
|||||||
|
|
||||||
for (const ColumnWithTypeAndName & column : header.getColumnsWithTypeAndName())
|
for (const ColumnWithTypeAndName & column : header.getColumnsWithTypeAndName())
|
||||||
{
|
{
|
||||||
ColumnDefinition column_definition(column.name, CharacterSet::binary, std::numeric_limits<uint32_t>::max(),
|
ColumnDefinition column_definition(column.name, CharacterSet::binary, 0, ColumnType::MYSQL_TYPE_STRING, 0, 0);
|
||||||
ColumnType::MYSQL_TYPE_STRING, 0, 0);
|
|
||||||
packet_sender->sendPacket(column_definition);
|
packet_sender->sendPacket(column_definition);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -37,7 +36,7 @@ void MySQLBlockOutputStream::writePrefix()
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void MySQLBlockOutputStream::write(const Block & block)
|
void MySQLWireBlockOutputStream::write(const Block & block)
|
||||||
{
|
{
|
||||||
size_t rows = block.rows();
|
size_t rows = block.rows();
|
||||||
|
|
||||||
@ -57,7 +56,7 @@ void MySQLBlockOutputStream::write(const Block & block)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void MySQLBlockOutputStream::writeSuffix()
|
void MySQLWireBlockOutputStream::writeSuffix()
|
||||||
{
|
{
|
||||||
QueryStatus * process_list_elem = context.getProcessListElement();
|
QueryStatus * process_list_elem = context.getProcessListElement();
|
||||||
CurrentThread::finalizePerformanceCounters();
|
CurrentThread::finalizePerformanceCounters();
|
||||||
@ -79,7 +78,7 @@ void MySQLBlockOutputStream::writeSuffix()
|
|||||||
packet_sender->sendPacket(EOF_Packet(0, 0), true);
|
packet_sender->sendPacket(EOF_Packet(0, 0), true);
|
||||||
}
|
}
|
||||||
|
|
||||||
void MySQLBlockOutputStream::flush()
|
void MySQLWireBlockOutputStream::flush()
|
||||||
{
|
{
|
||||||
packet_sender->out->next();
|
packet_sender->out->next();
|
||||||
}
|
}
|
@ -1,7 +1,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "IBlockOutputStream.h"
|
|
||||||
#include <Core/MySQLProtocol.h>
|
#include <Core/MySQLProtocol.h>
|
||||||
|
#include <DataStreams/IBlockOutputStream.h>
|
||||||
#include <Formats/FormatFactory.h>
|
#include <Formats/FormatFactory.h>
|
||||||
#include <Formats/FormatSettings.h>
|
#include <Formats/FormatSettings.h>
|
||||||
#include <Interpreters/Context.h>
|
#include <Interpreters/Context.h>
|
||||||
@ -11,10 +11,10 @@ namespace DB
|
|||||||
|
|
||||||
/** Interface for writing rows in MySQL Client/Server Protocol format.
|
/** Interface for writing rows in MySQL Client/Server Protocol format.
|
||||||
*/
|
*/
|
||||||
class MySQLBlockOutputStream : public IBlockOutputStream
|
class MySQLWireBlockOutputStream : public IBlockOutputStream
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
MySQLBlockOutputStream(WriteBuffer & buf, const Block & header, Context & context);
|
MySQLWireBlockOutputStream(WriteBuffer & buf, const Block & header, Context & context);
|
||||||
|
|
||||||
Block getHeader() const { return header; }
|
Block getHeader() const { return header; }
|
||||||
|
|
||||||
@ -31,6 +31,6 @@ private:
|
|||||||
FormatSettings format_settings;
|
FormatSettings format_settings;
|
||||||
};
|
};
|
||||||
|
|
||||||
using MySQLBlockOutputStreamPtr = std::shared_ptr<MySQLBlockOutputStream>;
|
using MySQLWireBlockOutputStreamPtr = std::shared_ptr<MySQLWireBlockOutputStream>;
|
||||||
|
|
||||||
}
|
}
|
19
dbms/src/Formats/MySQLWireFormat.cpp
Normal file
19
dbms/src/Formats/MySQLWireFormat.cpp
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
#include <Formats/MySQLWireBlockOutputStream.h>
|
||||||
|
|
||||||
|
|
||||||
|
namespace DB
|
||||||
|
{
|
||||||
|
|
||||||
|
void registerOutputFormatMySQLWire(FormatFactory & factory)
|
||||||
|
{
|
||||||
|
factory.registerOutputFormat("MySQLWire", [](
|
||||||
|
WriteBuffer & buf,
|
||||||
|
const Block & sample,
|
||||||
|
const Context & context,
|
||||||
|
const FormatSettings &)
|
||||||
|
{
|
||||||
|
return std::make_shared<MySQLWireBlockOutputStream>(buf, sample, const_cast<Context &>(context));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user