Simplified prepared statements handling for MySQL interface

This commit is contained in:
slvrtrn 2023-08-08 23:48:23 +02:00
parent e528eab7f5
commit d8904ffa69
5 changed files with 163 additions and 1 deletions

View File

@ -54,6 +54,9 @@ enum Command
COM_CHANGE_USER = 0x11,
COM_BINLOG_DUMP = 0x12,
COM_REGISTER_SLAVE = 0x15,
COM_STMT_PREPARE = 0x16,
COM_STMT_EXECUTE = 0x17,
COM_STMT_CLOSE = 0x19,
COM_RESET_CONNECTION = 0x1f,
COM_DAEMON = 0x1d,
COM_BINLOG_DUMP_GTID = 0x1e

View File

@ -0,0 +1,40 @@
#include <Core/MySQL/PacketsPreparedStatements.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <Common/logger_useful.h>
namespace DB
{
namespace MySQLProtocol
{
namespace PreparedStatements
{
size_t PrepareStatementResponseOK::getPayloadSize() const
{
return 13;
}
void PrepareStatementResponseOK::writePayloadImpl(WriteBuffer & buffer) const
{
buffer.write(reinterpret_cast<const char *>(&status), 1);
buffer.write(reinterpret_cast<const char *>(&statement_id), 4);
buffer.write(reinterpret_cast<const char *>(&num_columns), 2);
buffer.write(reinterpret_cast<const char *>(&num_params), 2);
buffer.write(reinterpret_cast<const char *>(&reserved_1), 1);
buffer.write(reinterpret_cast<const char *>(&warnings_count), 2);
buffer.write(0x0); // RESULTSET_METADATA_NONE
}
void PrepareStatementResponseOK::readPayloadImpl([[maybe_unused]] ReadBuffer & payload)
{
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "PrepareStatementResponseOK::readPayloadImpl is not implemented");
}
PrepareStatementResponseOK::PrepareStatementResponseOK(
uint32_t statement_id_, uint16_t num_columns_, uint16_t num_params_, uint16_t warnings_count_)
: statement_id(statement_id_), num_columns(num_columns_), num_params(num_params_), warnings_count(warnings_count_)
{
}
}
}
}

View File

@ -0,0 +1,35 @@
#pragma once
#include <Core/MySQL/IMySQLReadPacket.h>
#include <Core/MySQL/IMySQLWritePacket.h>
namespace DB
{
namespace MySQLProtocol
{
namespace PreparedStatements
{
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_prepare.html#sect_protocol_com_stmt_prepare_response_ok
class PrepareStatementResponseOK : public IMySQLWritePacket, public IMySQLReadPacket
{
public:
uint8_t status = 0x00;
uint32_t statement_id;
uint16_t num_columns;
uint16_t num_params;
uint8_t reserved_1 = 0;
uint16_t warnings_count;
protected:
size_t getPayloadSize() const override;
void readPayloadImpl(ReadBuffer & payload) override;
void writePayloadImpl(WriteBuffer & buffer) const override;
public:
PrepareStatementResponseOK(uint32_t statement_id_, uint16_t num_columns_, uint16_t num_params_, uint16_t warnings_count_);
};
}
}
}

View File

@ -4,6 +4,7 @@
#include <Common/NetException.h>
#include <Common/OpenSSLHelpers.h>
#include <Core/MySQL/PacketsGeneric.h>
#include <Core/MySQL/PacketsPreparedStatements.h>
#include <Core/MySQL/PacketsConnection.h>
#include <Core/MySQL/PacketsProtocolText.h>
#include <Core/NamesAndTypes.h>
@ -40,6 +41,7 @@ using namespace MySQLProtocol;
using namespace MySQLProtocol::Generic;
using namespace MySQLProtocol::ProtocolText;
using namespace MySQLProtocol::ConnectionPhase;
using namespace MySQLProtocol::PreparedStatements;
#if USE_SSL
using Poco::Net::SecureStreamSocket;
@ -181,6 +183,15 @@ void MySQLHandler::run()
case COM_PING:
comPing();
break;
case COM_STMT_PREPARE:
comStmtPrepare(payload);
break;
case COM_STMT_EXECUTE:
comStmtExecute(payload);
break;
case COM_STMT_CLOSE:
comStmtClose(payload);
break;
default:
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Command {} is not implemented.", command);
}
@ -254,7 +265,8 @@ void MySQLHandler::authenticate(const String & user_name, const String & auth_pl
{
try
{
// For compatibility with JavaScript MySQL client, Native41 authentication plugin is used when possible (if password is specified using double SHA1). Otherwise SHA256 plugin is used.
// For compatibility with JavaScript MySQL client, Native41 authentication plugin is used when possible
// (if password is specified using double SHA1). Otherwise, SHA256 plugin is used.
if (session->getAuthenticationTypeOrLogInFailure(user_name) == DB::AuthenticationType::SHA256_PASSWORD)
{
authPluginSSL();
@ -371,6 +383,68 @@ void MySQLHandler::comQuery(ReadBuffer & payload)
}
}
void MySQLHandler::comStmtPrepare(DB::ReadBuffer & payload)
{
String query;
readStringUntilEOF(query, payload);
uint32_t statement_id = current_prepared_statement_id;
if (current_prepared_statement_id == std::numeric_limits<uint32_t>::max())
{
current_prepared_statement_id = 0;
}
else
{
current_prepared_statement_id++;
}
// Key collisions should not happen here, as we remove the elements from the map with COM_STMT_CLOSE,
// and we have quite a big range of available identifiers with 32-bit unsigned integer
if (prepared_statements_map.contains(statement_id)) [[unlikely]]
{
LOG_ERROR(
log,
"Failed to store a new statement `{}` with id {}; it is already taken by `{}`",
query,
statement_id,
prepared_statements_map.at(statement_id));
packet_endpoint->sendPacket(ERRPacket(), true);
return;
}
prepared_statements_map.emplace(statement_id, query);
packet_endpoint->sendPacket(PrepareStatementResponseOK(statement_id, 0, 0, 0), true);
}
void MySQLHandler::comStmtExecute(ReadBuffer & payload)
{
uint32_t statement_id;
payload.readStrict(reinterpret_cast<char *>(&statement_id), 4);
if (!prepared_statements_map.contains(statement_id)) [[unlikely]]
{
LOG_ERROR(log, "Could not find prepared statement with id {}", statement_id);
packet_endpoint->sendPacket(ERRPacket(), true);
return;
}
// Temporary workaround as we work only with queries that do not bind any parameters atm
ReadBufferFromString com_query_payload(prepared_statements_map.at(statement_id));
MySQLHandler::comQuery(com_query_payload);
};
void MySQLHandler::comStmtClose([[maybe_unused]] ReadBuffer & payload) {
uint32_t statement_id;
payload.readStrict(reinterpret_cast<char *>(&statement_id), 4);
if (prepared_statements_map.contains(statement_id)) {
prepared_statements_map.erase(statement_id);
}
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_close.html
// No response packet is sent back to the client.
};
void MySQLHandler::authPluginSSL()
{
throw Exception(ErrorCodes::SUPPORT_IS_DISABLED,

View File

@ -56,6 +56,12 @@ protected:
void authenticate(const String & user_name, const String & auth_plugin_name, const String & auth_response);
void comStmtPrepare(ReadBuffer & payload);
void comStmtExecute(ReadBuffer & payload);
void comStmtClose(ReadBuffer & payload);
virtual void authPluginSSL();
virtual void finishHandshakeSSL(size_t packet_size, char * buf, size_t pos, std::function<void(size_t)> read_bytes, MySQLProtocol::ConnectionPhase::HandshakeResponse & packet);
@ -76,6 +82,10 @@ protected:
using Replacements = std::unordered_map<std::string, ReplacementFn>;
Replacements replacements;
uint32_t current_prepared_statement_id = 0;
using PreparedStatementsMap = std::unordered_map<uint32_t, String>;
PreparedStatementsMap prepared_statements_map;
std::unique_ptr<MySQLProtocol::Authentication::IPlugin> auth_plugin;
std::shared_ptr<ReadBufferFromPocoSocket> in;
std::shared_ptr<WriteBuffer> out;