ISSUES-4006 split mysql protocol

This commit is contained in:
zhang2014 2020-08-13 20:41:36 +08:00
parent 34f4c8972e
commit 96bd3ac34b
26 changed files with 95 additions and 127 deletions

View File

@ -73,13 +73,13 @@ Native41::Native41(const String & password, const String & auth_plugin_data)
void Native41::authenticate(
const String & user_name, std::optional<String> auth_response, Context & context,
std::shared_ptr<PacketEndpoint> packet_sender, bool, const Poco::Net::SocketAddress & address)
std::shared_ptr<PacketEndpoint> packet_endpoint, bool, const Poco::Net::SocketAddress & address)
{
if (!auth_response)
{
packet_sender->sendPacket(AuthSwitchRequest(getName(), scramble), true);
packet_endpoint->sendPacket(AuthSwitchRequest(getName(), scramble), true);
AuthSwitchResponse response;
packet_sender->receivePacket(response);
packet_endpoint->receivePacket(response);
auth_response = response.value;
}
@ -134,18 +134,18 @@ Sha256Password::Sha256Password(RSA & public_key_, RSA & private_key_, Poco::Logg
void Sha256Password::authenticate(
const String & user_name, std::optional<String> auth_response, Context & context,
std::shared_ptr<PacketEndpoint> packet_sender, bool is_secure_connection, const Poco::Net::SocketAddress & address)
std::shared_ptr<PacketEndpoint> packet_endpoint, bool is_secure_connection, const Poco::Net::SocketAddress & address)
{
if (!auth_response)
{
packet_sender->sendPacket(AuthSwitchRequest(getName(), scramble), true);
packet_endpoint->sendPacket(AuthSwitchRequest(getName(), scramble), true);
if (packet_sender->in->eof())
if (packet_endpoint->in->eof())
throw Exception("Client doesn't support authentication method " + getName() + " used by ClickHouse. Specifying user password using 'password_double_sha1_hex' may fix the problem.",
ErrorCodes::MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES);
AuthSwitchResponse response;
packet_sender->receivePacket(response);
packet_endpoint->receivePacket(response);
auth_response = response.value;
LOG_TRACE(log, "Authentication method mismatch.");
}
@ -174,11 +174,11 @@ void Sha256Password::authenticate(
LOG_TRACE(log, "Key: {}", pem);
AuthMoreData data(pem);
packet_sender->sendPacket(data, true);
packet_endpoint->sendPacket(data, true);
sent_public_key = true;
AuthSwitchResponse response;
packet_sender->receivePacket(response);
packet_endpoint->receivePacket(response);
auth_response = response.value;
}
else

View File

@ -33,7 +33,7 @@ public:
virtual void authenticate(
const String & user_name, std::optional<String> auth_response, Context & context,
std::shared_ptr<PacketEndpoint> packet_sender, bool is_secure_connection, const Poco::Net::SocketAddress & address) = 0;
std::shared_ptr<PacketEndpoint> packet_endpoint, bool is_secure_connection, const Poco::Net::SocketAddress & address) = 0;
};
/// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html
@ -50,7 +50,7 @@ public:
void authenticate(
const String & user_name, std::optional<String> auth_response, Context & context,
std::shared_ptr<PacketEndpoint> packet_sender, bool /* is_secure_connection */, const Poco::Net::SocketAddress & address) override;
std::shared_ptr<PacketEndpoint> packet_endpoint, bool /* is_secure_connection */, const Poco::Net::SocketAddress & address) override;
private:
String scramble;
@ -70,7 +70,7 @@ public:
void authenticate(
const String & user_name, std::optional<String> auth_response, Context & context,
std::shared_ptr<PacketEndpoint> packet_sender, bool is_secure_connection, const Poco::Net::SocketAddress & address) override;
std::shared_ptr<PacketEndpoint> packet_endpoint, bool is_secure_connection, const Poco::Net::SocketAddress & address) override;
private:
RSA & public_key;

View File

@ -1,6 +1,6 @@
#include <Core/MySQL/IMySQLReadPacket.h>
#include <sstream>
#include <Core/MySQL/PacketPayloadReadBuffer.h>
#include <IO/MySQLPacketPayloadReadBuffer.h>
#include <IO/LimitReadBuffer.h>
namespace DB
@ -16,7 +16,7 @@ namespace MySQLProtocol
void IMySQLReadPacket::readPayload(ReadBuffer & in, uint8_t & sequence_id)
{
PacketPayloadReadBuffer payload(in, sequence_id);
MySQLPacketPayloadReadBuffer payload(in, sequence_id);
payload.next();
readPayloadImpl(payload);
if (!payload.eof())

View File

@ -1,5 +1,5 @@
#include <Core/MySQL/IMySQLWritePacket.h>
#include <Core/MySQL/PacketPayloadWriteBuffer.h>
#include <IO/MySQLPacketPayloadWriteBuffer.h>
#include <sstream>
namespace DB
@ -10,7 +10,7 @@ namespace MySQLProtocol
void IMySQLWritePacket::writePayload(WriteBuffer & buffer, uint8_t & sequence_id) const
{
PacketPayloadWriteBuffer buf(buffer, getPayloadSize(), sequence_id);
MySQLPacketPayloadWriteBuffer buf(buffer, getPayloadSize(), sequence_id);
writePayloadImpl(buf);
buf.next();
if (buf.remainingPayloadSize())

View File

@ -5,7 +5,7 @@
#include <Core/MySQL/PacketsConnection.h>
#include <Core/MySQL/PacketsProtocolText.h>
#include <Core/MySQL/PacketsReplication.h>
#include <Core/MySQLReplication.h>
#include <Core/MySQL/MySQLReplication.h>
namespace DB
{
@ -53,7 +53,7 @@ void MySQLClient::connect()
in = std::make_shared<ReadBufferFromPocoSocket>(*socket);
out = std::make_shared<WriteBufferFromPocoSocket>(*socket);
packet_sender = std::make_shared<PacketEndpoint>(*in, *out, seq);
packet_endpoint = std::make_shared<PacketEndpoint>(*in, *out, seq);
handshake();
}
@ -71,7 +71,7 @@ void MySQLClient::disconnect()
void MySQLClient::handshake()
{
Handshake handshake;
packet_sender->receivePacket(handshake);
packet_endpoint->receivePacket(handshake);
if (handshake.auth_plugin_name != mysql_native_password)
{
throw Exception(
@ -83,12 +83,12 @@ void MySQLClient::handshake()
String auth_plugin_data = native41.getAuthPluginData();
HandshakeResponse handshake_response(
client_capability_flags, max_packet_size, charset_utf8, user, "", auth_plugin_data, mysql_native_password);
packet_sender->sendPacket<HandshakeResponse>(handshake_response, true);
client_capability_flags, MAX_PACKET_LENGTH, charset_utf8, user, "", auth_plugin_data, mysql_native_password);
packet_endpoint->sendPacket<HandshakeResponse>(handshake_response, true);
ResponsePacket packet_response(client_capability_flags, true);
packet_sender->receivePacket(packet_response);
packet_sender->resetSequenceId();
packet_endpoint->receivePacket(packet_response);
packet_endpoint->resetSequenceId();
if (packet_response.getType() == PACKET_ERR)
throw Exception(packet_response.err.error_message, ErrorCodes::UNKNOWN_PACKET_FROM_SERVER);
@ -99,10 +99,10 @@ void MySQLClient::handshake()
void MySQLClient::writeCommand(char command, String query)
{
WriteCommand write_command(command, query);
packet_sender->sendPacket<WriteCommand>(write_command, true);
packet_endpoint->sendPacket<WriteCommand>(write_command, true);
ResponsePacket packet_response(client_capability_flags);
packet_sender->receivePacket(packet_response);
packet_endpoint->receivePacket(packet_response);
switch (packet_response.getType())
{
case PACKET_ERR:
@ -112,17 +112,17 @@ void MySQLClient::writeCommand(char command, String query)
default:
break;
}
packet_sender->resetSequenceId();
packet_endpoint->resetSequenceId();
}
void MySQLClient::registerSlaveOnMaster(UInt32 slave_id)
{
RegisterSlave register_slave(slave_id);
packet_sender->sendPacket<RegisterSlave>(register_slave, true);
packet_endpoint->sendPacket<RegisterSlave>(register_slave, true);
ResponsePacket packet_response(client_capability_flags);
packet_sender->receivePacket(packet_response);
packet_sender->resetSequenceId();
packet_endpoint->receivePacket(packet_response);
packet_endpoint->resetSequenceId();
if (packet_response.getType() == PACKET_ERR)
throw Exception(packet_response.err.error_message, ErrorCodes::UNKNOWN_PACKET_FROM_SERVER);
}
@ -150,12 +150,12 @@ void MySQLClient::startBinlogDump(UInt32 slave_id, String replicate_db, String b
binlog_pos = binlog_pos < 4 ? 4 : binlog_pos;
BinlogDump binlog_dump(binlog_pos, binlog_file_name, slave_id);
packet_sender->sendPacket<BinlogDump>(binlog_dump, true);
packet_endpoint->sendPacket<BinlogDump>(binlog_dump, true);
}
BinlogEventPtr MySQLClient::readOneBinlogEvent(UInt64 milliseconds)
{
if (packet_sender->tryReceivePacket(replication, milliseconds))
if (packet_endpoint->tryReceivePacket(replication, milliseconds))
return replication.readOneEvent();
return {};

View File

@ -1,6 +1,6 @@
#pragma once
#include <Core/Types.h>
#include <Core/MySQLReplication.h>
#include <Core/MySQL/MySQLReplication.h>
#include <IO/ReadBufferFromPocoSocket.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteBufferFromPocoSocket.h>
@ -44,7 +44,6 @@ private:
uint8_t seq = 0;
const UInt8 charset_utf8 = 33;
const UInt32 max_packet_size = MySQLProtocol::MAX_PACKET_LENGTH;
const String mysql_native_password = "mysql_native_password";
MySQLFlavor replication;
@ -52,7 +51,7 @@ private:
std::shared_ptr<WriteBuffer> out;
std::unique_ptr<Poco::Net::StreamSocket> socket;
std::optional<Poco::Net::SocketAddress> address;
std::shared_ptr<PacketEndpoint> packet_sender;
std::shared_ptr<PacketEndpoint> packet_endpoint;
void handshake();
void registerSlaveOnMaster(UInt32 slave_id);

View File

@ -18,9 +18,9 @@ PacketEndpoint::PacketEndpoint(ReadBuffer & in_, WriteBuffer & out_, uint8_t & s
{
}
PacketPayloadReadBuffer PacketEndpoint::getPayload()
MySQLPacketPayloadReadBuffer PacketEndpoint::getPayload()
{
return PacketPayloadReadBuffer(*in, sequence_id);
return MySQLPacketPayloadReadBuffer(*in, sequence_id);
}
void PacketEndpoint::receivePacket(IMySQLReadPacket & packet)

View File

@ -4,7 +4,7 @@
#include <IO/WriteBuffer.h>
#include "IMySQLReadPacket.h"
#include "IMySQLWritePacket.h"
#include "PacketPayloadReadBuffer.h"
#include "IO/MySQLPacketPayloadReadBuffer.h"
namespace DB
{
@ -21,7 +21,6 @@ public:
uint8_t & sequence_id;
ReadBuffer * in;
WriteBuffer * out;
size_t max_packet_size = MAX_PACKET_LENGTH;
/// For writing.
PacketEndpoint(WriteBuffer & out_, uint8_t & sequence_id_);
@ -29,7 +28,7 @@ public:
/// For reading and writing.
PacketEndpoint(ReadBuffer & in_, WriteBuffer & out_, uint8_t & sequence_id_);
PacketPayloadReadBuffer getPayload();
MySQLPacketPayloadReadBuffer getPayload();
void receivePacket(IMySQLReadPacket & packet);

View File

@ -12,6 +12,8 @@ namespace MySQLProtocol
namespace Generic
{
const size_t MAX_PACKET_LENGTH = (1 << 24) - 1; // 16 mb
enum StatusFlags
{
SERVER_SESSION_STATE_CHANGED = 0x4000

View File

@ -59,13 +59,4 @@ void BinlogDump::writePayloadImpl(WriteBuffer & buffer) const
}
}
}
namespace DB::MySQLProtocol
{
extern const size_t MAX_PACKET_LENGTH = (1 << 24) - 1; // 16 mb
}

View File

@ -1,7 +1,7 @@
#pragma once
#include <Core/MySQL/PacketPayloadReadBuffer.h>
#include <Core/MySQL/PacketPayloadWriteBuffer.h>
#include <IO/MySQLPacketPayloadReadBuffer.h>
#include <IO/MySQLPacketPayloadWriteBuffer.h>
#include <Core/MySQL/PacketEndpoint.h>
/// Implementation of MySQL wire protocol.

View File

@ -1,6 +1,6 @@
#include <string>
#include <Core/MySQLClient.h>
#include <Core/MySQL/MySQLClient.h>
#include <Core/MySQL/Authentication.h>
#include <Core/MySQL/PacketsGeneric.h>
#include <Core/MySQL/PacketsConnection.h>
@ -27,7 +27,7 @@ int main(int argc, char ** argv)
String database;
UInt8 charset_utf8 = 33;
UInt32 max_packet_size = MySQLProtocol::MAX_PACKET_LENGTH;
UInt32 max_packet_size = MAX_PACKET_LENGTH;
String mysql_native_password = "mysql_native_password";
UInt32 server_capability_flags = CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH

View File

@ -18,7 +18,7 @@
#endif
#if USE_MYSQL
# include <Core/MySQLClient.h>
# include <Core/MySQL/MySQLClient.h>
# include <Databases/MySQL/DatabaseConnectionMySQL.h>
# include <Databases/MySQL/MaterializeMySQLSettings.h>
# include <Databases/MySQL/DatabaseMaterializeMySQL.h>

View File

@ -5,7 +5,7 @@
#if USE_MYSQL
#include <mysqlxx/Pool.h>
#include <Core/MySQLClient.h>
#include <Core/MySQL/MySQLClient.h>
#include <Databases/IDatabase.h>
#include <Databases/MySQL/MaterializeMySQLSettings.h>
#include <Databases/MySQL/MaterializeMySQLSyncThread.h>

View File

@ -7,7 +7,7 @@
#if USE_MYSQL
#include <Core/Types.h>
#include <Core/MySQLReplication.h>
#include <Core/MySQL/MySQLReplication.h>
#include <mysqlxx/Connection.h>
#include <mysqlxx/PoolWithFailover.h>

View File

@ -8,7 +8,7 @@
# include <mutex>
# include <Core/BackgroundSchedulePool.h>
# include <Core/MySQLClient.h>
# include <Core/MySQL/MySQLClient.h>
# include <DataStreams/BlockIO.h>
# include <DataTypes/DataTypeString.h>
# include <DataTypes/DataTypesNumber.h>

View File

@ -1,4 +1,4 @@
#include <Core/MySQL/PacketPayloadReadBuffer.h>
#include <IO/MySQLPacketPayloadReadBuffer.h>
#include <sstream>
namespace DB
@ -9,17 +9,16 @@ namespace ErrorCodes
extern const int UNKNOWN_PACKET_FROM_CLIENT;
}
namespace MySQLProtocol
{
const size_t MAX_PACKET_LENGTH = (1 << 24) - 1; // 16 mb
PacketPayloadReadBuffer::PacketPayloadReadBuffer(ReadBuffer & in_, uint8_t & sequence_id_)
MySQLPacketPayloadReadBuffer::MySQLPacketPayloadReadBuffer(ReadBuffer & in_, uint8_t & sequence_id_)
: ReadBuffer(in_.position(), 0), in(in_), sequence_id(sequence_id_) // not in.buffer().begin(), because working buffer may include previous packet
{
}
bool PacketPayloadReadBuffer::nextImpl()
bool MySQLPacketPayloadReadBuffer::nextImpl()
{
if (!has_read_header || (payload_length == max_packet_size && offset == payload_length))
if (!has_read_header || (payload_length == MAX_PACKET_LENGTH && offset == payload_length))
{
has_read_header = true;
working_buffer.resize(0);
@ -27,7 +26,7 @@ bool PacketPayloadReadBuffer::nextImpl()
payload_length = 0;
in.readStrict(reinterpret_cast<char *>(&payload_length), 3);
if (payload_length > max_packet_size)
if (payload_length > MAX_PACKET_LENGTH)
{
std::ostringstream tmp;
tmp << "Received packet with payload larger than max_packet_size: " << payload_length;
@ -64,5 +63,3 @@ bool PacketPayloadReadBuffer::nextImpl()
}
}
}

View File

@ -5,20 +5,14 @@
namespace DB
{
namespace MySQLProtocol
{
extern const size_t MAX_PACKET_LENGTH;
/** Reading packets.
* Internally, it calls (if no more data) next() method of the underlying ReadBufferFromPocoSocket, and sets the working buffer to the rest part of the current packet payload.
*/
class PacketPayloadReadBuffer : public ReadBuffer
class MySQLPacketPayloadReadBuffer : public ReadBuffer
{
private:
ReadBuffer & in;
uint8_t & sequence_id;
const size_t max_packet_size = MAX_PACKET_LENGTH;
bool has_read_header = false;
@ -32,10 +26,8 @@ protected:
bool nextImpl() override;
public:
PacketPayloadReadBuffer(ReadBuffer & in_, uint8_t & sequence_id_);
MySQLPacketPayloadReadBuffer(ReadBuffer & in_, uint8_t & sequence_id_);
};
}
}

View File

@ -1,14 +1,11 @@
#include <Core/MySQL/PacketPayloadWriteBuffer.h>
#include <IO/MySQLPacketPayloadWriteBuffer.h>
namespace DB
{
namespace MySQLProtocol
{
const size_t MAX_PACKET_LENGTH = (1 << 24) - 1; // 16 mb
extern const size_t MAX_PACKET_LENGTH;
PacketPayloadWriteBuffer::PacketPayloadWriteBuffer(WriteBuffer & out_, size_t payload_length_, uint8_t & sequence_id_)
MySQLPacketPayloadWriteBuffer::MySQLPacketPayloadWriteBuffer(WriteBuffer & out_, size_t payload_length_, uint8_t & sequence_id_)
: WriteBuffer(out_.position(), 0), out(out_), sequence_id(sequence_id_), total_left(payload_length_)
{
startNewPacket();
@ -16,7 +13,7 @@ PacketPayloadWriteBuffer::PacketPayloadWriteBuffer(WriteBuffer & out_, size_t pa
pos = out.position();
}
void PacketPayloadWriteBuffer::startNewPacket()
void MySQLPacketPayloadWriteBuffer::startNewPacket()
{
payload_length = std::min(total_left, MAX_PACKET_LENGTH);
bytes_written = 0;
@ -27,7 +24,7 @@ void PacketPayloadWriteBuffer::startNewPacket()
bytes += 4;
}
void PacketPayloadWriteBuffer::setWorkingBuffer()
void MySQLPacketPayloadWriteBuffer::setWorkingBuffer()
{
out.nextIfAtEnd();
working_buffer = WriteBuffer::Buffer(out.position(), out.position() + std::min(payload_length - bytes_written, out.available()));
@ -40,7 +37,7 @@ void PacketPayloadWriteBuffer::setWorkingBuffer()
}
}
void PacketPayloadWriteBuffer::nextImpl()
void MySQLPacketPayloadWriteBuffer::nextImpl()
{
const int written = pos - working_buffer.begin();
if (eof)
@ -57,5 +54,3 @@ void PacketPayloadWriteBuffer::nextImpl()
}
}
}

View File

@ -5,16 +5,13 @@
namespace DB
{
namespace MySQLProtocol
{
/** Writing packets.
* https://dev.mysql.com/doc/internals/en/mysql-packet.html
*/
class PacketPayloadWriteBuffer : public WriteBuffer
class MySQLPacketPayloadWriteBuffer : public WriteBuffer
{
public:
PacketPayloadWriteBuffer(WriteBuffer & out_, size_t payload_length_, uint8_t & sequence_id_);
MySQLPacketPayloadWriteBuffer(WriteBuffer & out_, size_t payload_length_, uint8_t & sequence_id_);
bool remainingPayloadSize() { return total_left; }
@ -37,5 +34,3 @@ private:
};
}
}

View File

@ -30,17 +30,17 @@ void MySQLOutputFormat::initialize()
if (header.columns())
{
packet_sender->sendPacket(LengthEncodedNumber(header.columns()));
packet_endpoint->sendPacket(LengthEncodedNumber(header.columns()));
for (size_t i = 0; i < header.columns(); i++)
{
const auto & column_name = header.getColumnsWithTypeAndName()[i].name;
packet_sender->sendPacket(getColumnDefinition(column_name, data_types[i]->getTypeId()));
packet_endpoint->sendPacket(getColumnDefinition(column_name, data_types[i]->getTypeId()));
}
if (!(context->mysql.client_capabilities & Capability::CLIENT_DEPRECATE_EOF))
{
packet_sender->sendPacket(EOFPacket(0, 0));
packet_endpoint->sendPacket(EOFPacket(0, 0));
}
}
}
@ -54,7 +54,7 @@ void MySQLOutputFormat::consume(Chunk chunk)
for (size_t i = 0; i < chunk.getNumRows(); i++)
{
ProtocolText::ResultSetRow row_packet(data_types, chunk.getColumns(), i);
packet_sender->sendPacket(row_packet);
packet_endpoint->sendPacket(row_packet);
}
}
@ -76,17 +76,17 @@ void MySQLOutputFormat::finalize()
const auto & header = getPort(PortKind::Main).getHeader();
if (header.columns() == 0)
packet_sender->sendPacket(OKPacket(0x0, context->mysql.client_capabilities, affected_rows, 0, 0, "", human_readable_info), true);
packet_endpoint->sendPacket(OKPacket(0x0, context->mysql.client_capabilities, affected_rows, 0, 0, "", human_readable_info), true);
else
if (context->mysql.client_capabilities & CLIENT_DEPRECATE_EOF)
packet_sender->sendPacket(OKPacket(0xfe, context->mysql.client_capabilities, affected_rows, 0, 0, "", human_readable_info), true);
packet_endpoint->sendPacket(OKPacket(0xfe, context->mysql.client_capabilities, affected_rows, 0, 0, "", human_readable_info), true);
else
packet_sender->sendPacket(EOFPacket(0, 0), true);
packet_endpoint->sendPacket(EOFPacket(0, 0), true);
}
void MySQLOutputFormat::flush()
{
packet_sender->out->next();
packet_endpoint->out->next();
}
void registerOutputFormatProcessorMySQLWire(FormatFactory & factory)

View File

@ -29,8 +29,7 @@ public:
void setContext(const Context & context_)
{
context = &context_;
packet_sender = std::make_unique<MySQLProtocol::PacketEndpoint>(out, const_cast<uint8_t &>(context_.mysql.sequence_id)); /// TODO: fix it
packet_sender->max_packet_size = context_.mysql.max_packet_size;
packet_endpoint = std::make_unique<MySQLProtocol::PacketEndpoint>(out, const_cast<uint8_t &>(context_.mysql.sequence_id)); /// TODO: fix it
}
void consume(Chunk) override;
@ -45,7 +44,7 @@ private:
bool initialized = false;
const Context * context = nullptr;
std::unique_ptr<MySQLProtocol::PacketEndpoint> packet_sender;
std::unique_ptr<MySQLProtocol::PacketEndpoint> packet_endpoint;
FormatSettings format_settings;
DataTypes data_types;
};

View File

@ -41,9 +41,9 @@ namespace DB
{
using namespace MySQLProtocol;
using namespace MySQLProtocol::ConnectionPhase;
using namespace MySQLProtocol::ProtocolText;
using namespace MySQLProtocol::Generic;
using namespace MySQLProtocol::ProtocolText;
using namespace MySQLProtocol::ConnectionPhase;
#if USE_SSL
using Poco::Net::SecureStreamSocket;
@ -91,13 +91,13 @@ void MySQLHandler::run()
in = std::make_shared<ReadBufferFromPocoSocket>(socket());
out = std::make_shared<WriteBufferFromPocoSocket>(socket());
packet_sender = std::make_shared<PacketEndpoint>(*in, *out, connection_context.mysql.sequence_id);
packet_endpoint = std::make_shared<PacketEndpoint>(*in, *out, connection_context.mysql.sequence_id);
try
{
Handshake handshake(server_capability_flags, connection_id, VERSION_STRING + String("-") + VERSION_NAME,
auth_plugin->getName(), auth_plugin->getAuthPluginData(), CharacterSet::utf8_general_ci);
packet_sender->sendPacket<Handshake>(handshake, true);
packet_endpoint->sendPacket<Handshake>(handshake, true);
LOG_TRACE(log, "Sent handshake");
@ -135,16 +135,16 @@ void MySQLHandler::run()
catch (const Exception & exc)
{
log->log(exc);
packet_sender->sendPacket(ERRPacket(exc.code(), "00000", exc.message()), true);
packet_endpoint->sendPacket(ERRPacket(exc.code(), "00000", exc.message()), true);
}
OKPacket ok_packet(0, handshake_response.capability_flags, 0, 0, 0);
packet_sender->sendPacket(ok_packet, true);
packet_endpoint->sendPacket(ok_packet, true);
while (true)
{
packet_sender->resetSequenceId();
PacketPayloadReadBuffer payload = packet_sender->getPayload();
packet_endpoint->resetSequenceId();
MySQLPacketPayloadReadBuffer payload = packet_endpoint->getPayload();
char command = 0;
payload.readStrict(command);
@ -184,7 +184,7 @@ void MySQLHandler::run()
}
catch (...)
{
packet_sender->sendPacket(ERRPacket(getCurrentExceptionCode(), "00000", getCurrentExceptionMessage(false)), true);
packet_endpoint->sendPacket(ERRPacket(getCurrentExceptionCode(), "00000", getCurrentExceptionMessage(false)), true);
}
}
}
@ -233,11 +233,11 @@ void MySQLHandler::finishHandshake(MySQLProtocol::ConnectionPhase::HandshakeResp
packet_size = PACKET_HEADER_SIZE + payload_size;
WriteBufferFromOwnString buf_for_handshake_response;
buf_for_handshake_response.write(buf, pos);
copyData(*packet_sender->in, buf_for_handshake_response, packet_size - pos);
copyData(*packet_endpoint->in, buf_for_handshake_response, packet_size - pos);
ReadBufferFromString payload(buf_for_handshake_response.str());
payload.ignore(PACKET_HEADER_SIZE);
packet.readPayloadWithUnpacked(payload);
packet_sender->sequence_id++;
packet_endpoint->sequence_id++;
}
}
@ -254,12 +254,12 @@ void MySQLHandler::authenticate(const String & user_name, const String & auth_pl
}
std::optional<String> auth_response = auth_plugin_name == auth_plugin->getName() ? std::make_optional<String>(initial_auth_response) : std::nullopt;
auth_plugin->authenticate(user_name, auth_response, connection_context, packet_sender, secure_connection, socket().peerAddress());
auth_plugin->authenticate(user_name, auth_response, connection_context, packet_endpoint, secure_connection, socket().peerAddress());
}
catch (const Exception & exc)
{
LOG_ERROR(log, "Authentication for user {} failed.", user_name);
packet_sender->sendPacket(ERRPacket(exc.code(), "00000", exc.message()), true);
packet_endpoint->sendPacket(ERRPacket(exc.code(), "00000", exc.message()), true);
throw;
}
LOG_INFO(log, "Authentication for user {} succeeded.", user_name);
@ -271,7 +271,7 @@ void MySQLHandler::comInitDB(ReadBuffer & payload)
readStringUntilEOF(database, payload);
LOG_DEBUG(log, "Setting current database to {}", database);
connection_context.setCurrentDatabase(database);
packet_sender->sendPacket(OKPacket(0, client_capability_flags, 0, 0, 1), true);
packet_endpoint->sendPacket(OKPacket(0, client_capability_flags, 0, 0, 1), true);
}
void MySQLHandler::comFieldList(ReadBuffer & payload)
@ -286,14 +286,14 @@ void MySQLHandler::comFieldList(ReadBuffer & payload)
ColumnDefinition column_definition(
database, packet.table, packet.table, column.name, column.name, CharacterSet::binary, 100, ColumnType::MYSQL_TYPE_STRING, 0, 0
);
packet_sender->sendPacket(column_definition);
packet_endpoint->sendPacket(column_definition);
}
packet_sender->sendPacket(OKPacket(0xfe, client_capability_flags, 0, 0, 0), true);
packet_endpoint->sendPacket(OKPacket(0xfe, client_capability_flags, 0, 0, 0), true);
}
void MySQLHandler::comPing()
{
packet_sender->sendPacket(OKPacket(0x0, client_capability_flags, 0, 0, 0), true);
packet_endpoint->sendPacket(OKPacket(0x0, client_capability_flags, 0, 0, 0), true);
}
static bool isFederatedServerSetupSetCommand(const String & query);
@ -306,7 +306,7 @@ void MySQLHandler::comQuery(ReadBuffer & payload)
// As Clickhouse doesn't support these statements, we just send OK packet in response.
if (isFederatedServerSetupSetCommand(query))
{
packet_sender->sendPacket(OKPacket(0x00, client_capability_flags, 0, 0, 0), true);
packet_endpoint->sendPacket(OKPacket(0x00, client_capability_flags, 0, 0, 0), true);
}
else
{
@ -336,7 +336,7 @@ void MySQLHandler::comQuery(ReadBuffer & payload)
);
if (!with_output)
packet_sender->sendPacket(OKPacket(0x00, client_capability_flags, 0, 0, 0), true);
packet_endpoint->sendPacket(OKPacket(0x00, client_capability_flags, 0, 0, 0), true);
}
}
@ -380,9 +380,8 @@ void MySQLHandlerSSL::finishHandshakeSSL(
in = std::make_shared<ReadBufferFromPocoSocket>(*ss);
out = std::make_shared<WriteBufferFromPocoSocket>(*ss);
connection_context.mysql.sequence_id = 2;
packet_sender = std::make_shared<PacketEndpoint>(*in, *out, connection_context.mysql.sequence_id);
packet_sender->max_packet_size = connection_context.mysql.max_packet_size;
packet_sender->receivePacket(packet); /// Reading HandshakeResponse from secure socket.
packet_endpoint = std::make_shared<PacketEndpoint>(*in, *out, connection_context.mysql.sequence_id);
packet_endpoint->receivePacket(packet); /// Reading HandshakeResponse from secure socket.
}
#endif

View File

@ -58,7 +58,7 @@ protected:
Context connection_context;
std::shared_ptr<MySQLProtocol::PacketEndpoint> packet_sender;
std::shared_ptr<MySQLProtocol::PacketEndpoint> packet_endpoint;
private:
size_t connection_id = 0;