add error to MySQL client connect

This commit is contained in:
BohuTANG 2020-04-20 23:06:17 +08:00 committed by zhang2014
parent ab8b847e66
commit 1a9118d722
4 changed files with 44 additions and 71 deletions

View File

@ -2,13 +2,10 @@
namespace DB namespace DB
{ {
using namespace MySQLProtocol;
using namespace MySQLProtocol::Authentication; using namespace MySQLProtocol::Authentication;
namespace ErrorCodes namespace ErrorCodes
{ {
extern const int NETWORK_ERROR;
extern const int SOCKET_TIMEOUT;
extern const int UNKNOWN_PACKET_FROM_SERVER; extern const int UNKNOWN_PACKET_FROM_SERVER;
} }
@ -18,36 +15,22 @@ MySQLClient::MySQLClient(const String & host_, UInt16 port_, const String & user
client_capability_flags = CLIENT_PROTOCOL_41 | CLIENT_PLUGIN_AUTH | CLIENT_SECURE_CONNECTION; client_capability_flags = CLIENT_PROTOCOL_41 | CLIENT_PLUGIN_AUTH | CLIENT_SECURE_CONNECTION;
} }
void MySQLClient::connect() bool MySQLClient::connect()
{ {
try if (connected)
{
if (connected)
{
close();
}
socket = std::make_unique<Poco::Net::StreamSocket>();
address = DNSResolver::instance().resolveAddress(host, port);
socket->connect(*address);
in = std::make_shared<ReadBufferFromPocoSocket>(*socket);
out = std::make_shared<WriteBufferFromPocoSocket>(*socket);
packet_sender = std::make_shared<PacketSender>(*in, *out, seq);
connected = true;
handshake();
}
catch (Poco::Net::NetException & e)
{ {
close(); close();
throw NetException(e.displayText(), ErrorCodes::NETWORK_ERROR);
}
catch (Poco::TimeoutException & e)
{
close();
throw NetException(e.displayText(), ErrorCodes::SOCKET_TIMEOUT);
} }
socket = std::make_unique<Poco::Net::StreamSocket>();
address = DNSResolver::instance().resolveAddress(host, port);
socket->connect(*address);
in = std::make_shared<ReadBufferFromPocoSocket>(*socket);
out = std::make_shared<WriteBufferFromPocoSocket>(*socket);
packet_sender = std::make_shared<PacketSender>(*in, *out, seq);
connected = true;
return handshake();
} }
void MySQLClient::close() void MySQLClient::close()
@ -61,7 +44,7 @@ void MySQLClient::close()
} }
/// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html /// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html
void MySQLClient::handshake() bool MySQLClient::handshake()
{ {
Handshake handshake; Handshake handshake;
packet_sender->receivePacket(handshake); packet_sender->receivePacket(handshake);
@ -81,16 +64,19 @@ void MySQLClient::handshake()
PacketResponse packetResponse(handshake.capability_flags); PacketResponse packetResponse(handshake.capability_flags);
packet_sender->receivePacket(packetResponse); packet_sender->receivePacket(packetResponse);
if (packetResponse.getType() != PACKET_ERR)
switch (packetResponse.getType())
{ {
case PACKET_OK: return true;
break; }
case PACKET_ERR: else
throw Exception(packetResponse.err.error_message, ErrorCodes::UNKNOWN_PACKET_FROM_SERVER); {
break; last_error = packetResponse.err.error_message;
case PACKET_EOF: return false;
break;
} }
} }
String MySQLClient::error()
{
return last_error;
}
} }

View File

@ -14,14 +14,15 @@
namespace DB namespace DB
{ {
using namespace MySQLProtocol;
class MySQLClient class MySQLClient
{ {
public: public:
MySQLClient(const String & host_, UInt16 port_, const String & user_, const String & password_, const String & database_); MySQLClient(const String & host_, UInt16 port_, const String & user_, const String & password_, const String & database_);
bool connect();
void connect();
void close(); void close();
String error();
private: private:
String host; String host;
@ -32,6 +33,7 @@ private:
bool connected = false; bool connected = false;
UInt32 client_capability_flags = 0; UInt32 client_capability_flags = 0;
String last_error;
uint8_t seq = 0; uint8_t seq = 0;
UInt8 charset_utf8 = 33; UInt8 charset_utf8 = 33;
@ -42,10 +44,8 @@ private:
std::shared_ptr<WriteBuffer> out; std::shared_ptr<WriteBuffer> out;
std::unique_ptr<Poco::Net::StreamSocket> socket; std::unique_ptr<Poco::Net::StreamSocket> socket;
std::optional<Poco::Net::SocketAddress> address; std::optional<Poco::Net::SocketAddress> address;
std::shared_ptr<PacketSender> packet_sender;
void handshake(); bool handshake();
protected:
std::shared_ptr<MySQLProtocol::PacketSender> packet_sender;
}; };
} }

View File

@ -491,49 +491,30 @@ public:
void readPayloadImpl(ReadBuffer & buffer) override void readPayloadImpl(ReadBuffer & buffer) override
{ {
/// 1-byte: [0a] protocol version
buffer.readStrict(reinterpret_cast<char *>(&protocol_version), 1); buffer.readStrict(reinterpret_cast<char *>(&protocol_version), 1);
/// string[NUL]: server version
readNullTerminated(server_version, buffer); readNullTerminated(server_version, buffer);
/// 4-bytes: connection id
buffer.readStrict(reinterpret_cast<char *>(&connection_id), 4); buffer.readStrict(reinterpret_cast<char *>(&connection_id), 4);
/// 8-bytes: auth-plugin-data-part-1
auth_plugin_data.resize(AUTH_PLUGIN_DATA_PART_1_LENGTH); auth_plugin_data.resize(AUTH_PLUGIN_DATA_PART_1_LENGTH);
buffer.readStrict(auth_plugin_data.data(), AUTH_PLUGIN_DATA_PART_1_LENGTH); buffer.readStrict(auth_plugin_data.data(), AUTH_PLUGIN_DATA_PART_1_LENGTH);
/// 1-byte: [00] filler
buffer.ignore(1); buffer.ignore(1);
/// 2-bytes: capability flags lower 2-bytes
buffer.readStrict(reinterpret_cast<char *>(&capability_flags), 2); buffer.readStrict(reinterpret_cast<char *>(&capability_flags), 2);
/// 1-byte: character set
buffer.readStrict(reinterpret_cast<char *>(&character_set), 1); buffer.readStrict(reinterpret_cast<char *>(&character_set), 1);
/// 2-bytes: status flags(ignored)
buffer.readStrict(reinterpret_cast<char *>(&status_flags), 2); buffer.readStrict(reinterpret_cast<char *>(&status_flags), 2);
/// 2-bytes: capability flags upper 2-bytes
buffer.readStrict((reinterpret_cast<char *>(&capability_flags)) + 2, 2); buffer.readStrict((reinterpret_cast<char *>(&capability_flags)) + 2, 2);
UInt8 auth_plugin_data_length = 0; UInt8 auth_plugin_data_length = 0;
if (capability_flags & MySQLProtocol::CLIENT_PLUGIN_AUTH) if (capability_flags & MySQLProtocol::CLIENT_PLUGIN_AUTH)
{ {
/// 1-byte: length of auth-plugin-data
buffer.readStrict(reinterpret_cast<char *>(&auth_plugin_data_length), 1); buffer.readStrict(reinterpret_cast<char *>(&auth_plugin_data_length), 1);
} }
else else
{ {
/// 1-byte: [00]
buffer.ignore(1); buffer.ignore(1);
} }
/// string[10] reserved (all [00])
buffer.ignore(10); buffer.ignore(10);
if (capability_flags & MySQLProtocol::CLIENT_SECURE_CONNECTION) if (capability_flags & MySQLProtocol::CLIENT_SECURE_CONNECTION)
{ {
UInt8 part2_length = (auth_plugin_data_length - AUTH_PLUGIN_DATA_PART_1_LENGTH) > 13 UInt8 part2_length = (auth_plugin_data_length - AUTH_PLUGIN_DATA_PART_1_LENGTH) > 13
@ -764,7 +745,7 @@ public:
String session_state_changes; String session_state_changes;
String info; String info;
OK_Packet(uint32_t capabilities_) : capabilities(capabilities_) { } OK_Packet(uint32_t capabilities_) : header(0x00), capabilities(capabilities_), affected_rows(0), last_insert_id(0), status_flags(0) { }
OK_Packet( OK_Packet(
uint8_t header_, uint8_t header_,
uint32_t capabilities_, uint32_t capabilities_,
@ -776,6 +757,7 @@ public:
: header(header_) : header(header_)
, capabilities(capabilities_) , capabilities(capabilities_)
, affected_rows(affected_rows_) , affected_rows(affected_rows_)
, last_insert_id(0)
, warnings(warnings_) , warnings(warnings_)
, status_flags(status_flags_) , status_flags(status_flags_)
, session_state_changes(std::move(session_state_changes_)) , session_state_changes(std::move(session_state_changes_))
@ -814,7 +796,7 @@ public:
{ {
buffer.write(header); buffer.write(header);
writeLengthEncodedNumber(affected_rows, buffer); writeLengthEncodedNumber(affected_rows, buffer);
writeLengthEncodedNumber(0, buffer); /// last insert-id writeLengthEncodedNumber(last_insert_id, buffer); /// last insert-id
if (capabilities & CLIENT_PROTOCOL_41) if (capabilities & CLIENT_PROTOCOL_41)
{ {
@ -859,7 +841,8 @@ public:
auto len = readLengthEncodedNumber(payload); auto len = readLengthEncodedNumber(payload);
info.resize(len); info.resize(len);
payload.readStrict(info.data(), len); payload.readStrict(info.data(), len);
if (status_flags & SERVER_SESSION_STATE_CHANGED) { if (status_flags & SERVER_SESSION_STATE_CHANGED)
{
len = readLengthEncodedNumber(payload); len = readLengthEncodedNumber(payload);
session_state_changes.resize(len); session_state_changes.resize(len);
payload.readStrict(session_state_changes.data(), len); payload.readStrict(session_state_changes.data(), len);

View File

@ -1,3 +1,4 @@
#include <iostream>
#include <Core/MySQLClient.h> #include <Core/MySQLClient.h>
@ -5,9 +6,12 @@ int main(int, char **)
{ {
using namespace DB; using namespace DB;
UInt16 port = 4407; UInt16 port = 9001;
String host = "127.0.0.1", user = "mock", password = "mock"; String host = "127.0.0.1", user = "default", password = "123";
MySQLClient client(host, port, user, password, ""); MySQLClient client(host, port, user, password, "");
client.connect(); if (!client.connect())
{
std::cerr << "Connect Error: " << client.error() << std::endl;
}
return 0; return 0;
} }