add error to MySQL client connect

This commit is contained in:
BohuTANG 2020-04-20 23:06:17 +08:00 committed by zhang2014
parent ab8b847e66
commit 1a9118d722
4 changed files with 44 additions and 71 deletions

View File

@ -2,13 +2,10 @@
namespace DB
{
using namespace MySQLProtocol;
using namespace MySQLProtocol::Authentication;
namespace ErrorCodes
{
extern const int NETWORK_ERROR;
extern const int SOCKET_TIMEOUT;
extern const int UNKNOWN_PACKET_FROM_SERVER;
}
@ -18,36 +15,22 @@ MySQLClient::MySQLClient(const String & host_, UInt16 port_, const String & user
client_capability_flags = CLIENT_PROTOCOL_41 | CLIENT_PLUGIN_AUTH | CLIENT_SECURE_CONNECTION;
}
void MySQLClient::connect()
bool MySQLClient::connect()
{
try
{
if (connected)
{
close();
}
socket = std::make_unique<Poco::Net::StreamSocket>();
address = DNSResolver::instance().resolveAddress(host, port);
socket->connect(*address);
in = std::make_shared<ReadBufferFromPocoSocket>(*socket);
out = std::make_shared<WriteBufferFromPocoSocket>(*socket);
packet_sender = std::make_shared<PacketSender>(*in, *out, seq);
connected = true;
handshake();
}
catch (Poco::Net::NetException & e)
if (connected)
{
close();
throw NetException(e.displayText(), ErrorCodes::NETWORK_ERROR);
}
catch (Poco::TimeoutException & e)
{
close();
throw NetException(e.displayText(), ErrorCodes::SOCKET_TIMEOUT);
}
socket = std::make_unique<Poco::Net::StreamSocket>();
address = DNSResolver::instance().resolveAddress(host, port);
socket->connect(*address);
in = std::make_shared<ReadBufferFromPocoSocket>(*socket);
out = std::make_shared<WriteBufferFromPocoSocket>(*socket);
packet_sender = std::make_shared<PacketSender>(*in, *out, seq);
connected = true;
return handshake();
}
void MySQLClient::close()
@ -61,7 +44,7 @@ void MySQLClient::close()
}
/// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html
void MySQLClient::handshake()
bool MySQLClient::handshake()
{
Handshake handshake;
packet_sender->receivePacket(handshake);
@ -81,16 +64,19 @@ void MySQLClient::handshake()
PacketResponse packetResponse(handshake.capability_flags);
packet_sender->receivePacket(packetResponse);
switch (packetResponse.getType())
if (packetResponse.getType() != PACKET_ERR)
{
case PACKET_OK:
break;
case PACKET_ERR:
throw Exception(packetResponse.err.error_message, ErrorCodes::UNKNOWN_PACKET_FROM_SERVER);
break;
case PACKET_EOF:
break;
return true;
}
else
{
last_error = packetResponse.err.error_message;
return false;
}
}
String MySQLClient::error()
{
return last_error;
}
}

View File

@ -14,14 +14,15 @@
namespace DB
{
using namespace MySQLProtocol;
class MySQLClient
{
public:
MySQLClient(const String & host_, UInt16 port_, const String & user_, const String & password_, const String & database_);
void connect();
bool connect();
void close();
String error();
private:
String host;
@ -32,6 +33,7 @@ private:
bool connected = false;
UInt32 client_capability_flags = 0;
String last_error;
uint8_t seq = 0;
UInt8 charset_utf8 = 33;
@ -42,10 +44,8 @@ private:
std::shared_ptr<WriteBuffer> out;
std::unique_ptr<Poco::Net::StreamSocket> socket;
std::optional<Poco::Net::SocketAddress> address;
std::shared_ptr<PacketSender> packet_sender;
void handshake();
protected:
std::shared_ptr<MySQLProtocol::PacketSender> packet_sender;
bool handshake();
};
}

View File

@ -491,49 +491,30 @@ public:
void readPayloadImpl(ReadBuffer & buffer) override
{
/// 1-byte: [0a] protocol version
buffer.readStrict(reinterpret_cast<char *>(&protocol_version), 1);
/// string[NUL]: server version
readNullTerminated(server_version, buffer);
/// 4-bytes: connection id
buffer.readStrict(reinterpret_cast<char *>(&connection_id), 4);
/// 8-bytes: auth-plugin-data-part-1
auth_plugin_data.resize(AUTH_PLUGIN_DATA_PART_1_LENGTH);
buffer.readStrict(auth_plugin_data.data(), AUTH_PLUGIN_DATA_PART_1_LENGTH);
/// 1-byte: [00] filler
buffer.ignore(1);
/// 2-bytes: capability flags lower 2-bytes
buffer.readStrict(reinterpret_cast<char *>(&capability_flags), 2);
/// 1-byte: character set
buffer.readStrict(reinterpret_cast<char *>(&character_set), 1);
/// 2-bytes: status flags(ignored)
buffer.readStrict(reinterpret_cast<char *>(&status_flags), 2);
/// 2-bytes: capability flags upper 2-bytes
buffer.readStrict((reinterpret_cast<char *>(&capability_flags)) + 2, 2);
UInt8 auth_plugin_data_length = 0;
if (capability_flags & MySQLProtocol::CLIENT_PLUGIN_AUTH)
{
/// 1-byte: length of auth-plugin-data
buffer.readStrict(reinterpret_cast<char *>(&auth_plugin_data_length), 1);
}
else
{
/// 1-byte: [00]
buffer.ignore(1);
}
/// string[10] reserved (all [00])
buffer.ignore(10);
if (capability_flags & MySQLProtocol::CLIENT_SECURE_CONNECTION)
{
UInt8 part2_length = (auth_plugin_data_length - AUTH_PLUGIN_DATA_PART_1_LENGTH) > 13
@ -764,7 +745,7 @@ public:
String session_state_changes;
String info;
OK_Packet(uint32_t capabilities_) : capabilities(capabilities_) { }
OK_Packet(uint32_t capabilities_) : header(0x00), capabilities(capabilities_), affected_rows(0), last_insert_id(0), status_flags(0) { }
OK_Packet(
uint8_t header_,
uint32_t capabilities_,
@ -776,6 +757,7 @@ public:
: header(header_)
, capabilities(capabilities_)
, affected_rows(affected_rows_)
, last_insert_id(0)
, warnings(warnings_)
, status_flags(status_flags_)
, session_state_changes(std::move(session_state_changes_))
@ -814,7 +796,7 @@ public:
{
buffer.write(header);
writeLengthEncodedNumber(affected_rows, buffer);
writeLengthEncodedNumber(0, buffer); /// last insert-id
writeLengthEncodedNumber(last_insert_id, buffer); /// last insert-id
if (capabilities & CLIENT_PROTOCOL_41)
{
@ -859,7 +841,8 @@ public:
auto len = readLengthEncodedNumber(payload);
info.resize(len);
payload.readStrict(info.data(), len);
if (status_flags & SERVER_SESSION_STATE_CHANGED) {
if (status_flags & SERVER_SESSION_STATE_CHANGED)
{
len = readLengthEncodedNumber(payload);
session_state_changes.resize(len);
payload.readStrict(session_state_changes.data(), len);

View File

@ -1,3 +1,4 @@
#include <iostream>
#include <Core/MySQLClient.h>
@ -5,9 +6,12 @@ int main(int, char **)
{
using namespace DB;
UInt16 port = 4407;
String host = "127.0.0.1", user = "mock", password = "mock";
UInt16 port = 9001;
String host = "127.0.0.1", user = "default", password = "123";
MySQLClient client(host, port, user, password, "");
client.connect();
if (!client.connect())
{
std::cerr << "Connect Error: " << client.error() << std::endl;
}
return 0;
}