ClickHouse/base/mysqlxx/Pool.cpp
Haavard Kvaalen efa94c5735 Fix use after free in mysqlxx::Pool::tryGet()
tryGet() created an Entry object that referenced a Connection, but then
deleted the Connection just before the Entry went out of scope.
2021-03-04 16:51:27 +01:00

343 lines
10 KiB
C++

#if __has_include(<mysql.h>)
#include <mysql.h>
#include <mysqld_error.h>
#else
#include <mysql/mysql.h>
#include <mysql/mysqld_error.h>
#endif
#include <mysqlxx/Pool.h>
#include <common/sleep.h>
#include <Poco/Util/LayeredConfiguration.h>
namespace mysqlxx
{
void Pool::Entry::incrementRefCount()
{
if (!data)
return;
/// First reference, initialize thread
if (data->ref_count.fetch_add(1) == 0)
mysql_thread_init();
}
void Pool::Entry::decrementRefCount()
{
if (!data)
return;
/// We were the last user of this thread, deinitialize it
if (data->ref_count.fetch_sub(1) == 1)
mysql_thread_end();
}
Pool::Pool(const Poco::Util::AbstractConfiguration & cfg, const std::string & config_name,
unsigned default_connections_, unsigned max_connections_,
const char * parent_config_name_)
: logger(Poco::Logger::get("mysqlxx::Pool"))
, default_connections(default_connections_)
, max_connections(max_connections_)
{
server = cfg.getString(config_name + ".host");
if (parent_config_name_)
{
const std::string parent_config_name(parent_config_name_);
db = cfg.getString(config_name + ".db", cfg.getString(parent_config_name + ".db", ""));
user = cfg.has(config_name + ".user")
? cfg.getString(config_name + ".user")
: cfg.getString(parent_config_name + ".user");
password = cfg.has(config_name + ".password")
? cfg.getString(config_name + ".password")
: cfg.getString(parent_config_name + ".password");
if (!cfg.has(config_name + ".port") && !cfg.has(config_name + ".socket")
&& !cfg.has(parent_config_name + ".port") && !cfg.has(parent_config_name + ".socket"))
throw Poco::Exception("mysqlxx::Pool configuration: expected port or socket");
port = cfg.has(config_name + ".port")
? cfg.getInt(config_name + ".port")
: cfg.getInt(parent_config_name + ".port", 0);
socket = cfg.has(config_name + ".socket")
? cfg.getString(config_name + ".socket")
: cfg.getString(parent_config_name + ".socket", "");
ssl_ca = cfg.has(config_name + ".ssl_ca")
? cfg.getString(config_name + ".ssl_ca")
: cfg.getString(parent_config_name + ".ssl_ca", "");
ssl_cert = cfg.has(config_name + ".ssl_cert")
? cfg.getString(config_name + ".ssl_cert")
: cfg.getString(parent_config_name + ".ssl_cert", "");
ssl_key = cfg.has(config_name + ".ssl_key")
? cfg.getString(config_name + ".ssl_key")
: cfg.getString(parent_config_name + ".ssl_key", "");
enable_local_infile = cfg.getBool(config_name + ".enable_local_infile",
cfg.getBool(parent_config_name + ".enable_local_infile", MYSQLXX_DEFAULT_ENABLE_LOCAL_INFILE));
opt_reconnect = cfg.getBool(config_name + ".opt_reconnect",
cfg.getBool(parent_config_name + ".opt_reconnect", MYSQLXX_DEFAULT_MYSQL_OPT_RECONNECT));
}
else
{
db = cfg.getString(config_name + ".db", "");
user = cfg.getString(config_name + ".user");
password = cfg.getString(config_name + ".password");
if (!cfg.has(config_name + ".port") && !cfg.has(config_name + ".socket"))
throw Poco::Exception("mysqlxx::Pool configuration: expected port or socket");
port = cfg.getInt(config_name + ".port", 0);
socket = cfg.getString(config_name + ".socket", "");
ssl_ca = cfg.getString(config_name + ".ssl_ca", "");
ssl_cert = cfg.getString(config_name + ".ssl_cert", "");
ssl_key = cfg.getString(config_name + ".ssl_key", "");
enable_local_infile = cfg.getBool(
config_name + ".enable_local_infile", MYSQLXX_DEFAULT_ENABLE_LOCAL_INFILE);
opt_reconnect = cfg.getBool(config_name + ".opt_reconnect", MYSQLXX_DEFAULT_MYSQL_OPT_RECONNECT);
}
connect_timeout = cfg.getInt(config_name + ".connect_timeout",
cfg.getInt("mysql_connect_timeout",
MYSQLXX_DEFAULT_TIMEOUT));
rw_timeout =
cfg.getInt(config_name + ".rw_timeout",
cfg.getInt("mysql_rw_timeout",
MYSQLXX_DEFAULT_RW_TIMEOUT));
}
Pool::~Pool()
{
std::lock_guard<std::mutex> lock(mutex);
for (auto & connection : connections)
delete static_cast<Connection *>(connection);
}
Pool::Entry Pool::get()
{
std::unique_lock<std::mutex> lock(mutex);
initialize();
for (;;)
{
logger.trace("(%s): Iterating through existing MySQL connections", getDescription());
for (auto & connection : connections)
{
if (connection->ref_count == 0)
return Entry(connection, this);
}
logger.trace("(%s): Trying to allocate a new connection.", getDescription());
if (connections.size() < static_cast<size_t>(max_connections))
{
Connection * conn = allocConnection();
if (conn)
return Entry(conn, this);
logger.trace("(%s): Unable to create a new connection: Allocation failed.", getDescription());
}
else
{
logger.trace("(%s): Unable to create a new connection: Max number of connections has been reached.", getDescription());
}
lock.unlock();
logger.trace("(%s): Sleeping for %d seconds.", getDescription(), MYSQLXX_POOL_SLEEP_ON_CONNECT_FAIL);
sleepForSeconds(MYSQLXX_POOL_SLEEP_ON_CONNECT_FAIL);
lock.lock();
}
}
Pool::Entry Pool::tryGet()
{
std::lock_guard<std::mutex> lock(mutex);
initialize();
/// Try to pick an idle connection from already allocated
for (auto connection_it = connections.cbegin(); connection_it != connections.cend();)
{
Connection * connection_ptr = *connection_it;
/// Fixme: There is a race condition here b/c we do not synchronize with Pool::Entry's copy-assignment operator
if (connection_ptr->ref_count == 0)
{
{
Entry res(connection_ptr, this);
if (res.tryForceConnected()) /// Tries to reestablish connection as well
return res;
}
logger.debug("(%s): Idle connection to MySQL server cannot be recovered, dropping it.", getDescription());
/// This one is disconnected, cannot be reestablished and so needs to be disposed of.
connection_it = connections.erase(connection_it);
::delete connection_ptr; /// TODO: Manual memory management is awkward (matches allocConnection() method)
}
else
++connection_it;
}
if (connections.size() >= max_connections)
throw Poco::Exception("mysqlxx::Pool is full");
Connection * connection_ptr = allocConnection(true);
if (connection_ptr)
return {connection_ptr, this};
return {};
}
void Pool::removeConnection(Connection* connection)
{
logger.trace("(%s): Removing connection.", getDescription());
std::lock_guard<std::mutex> lock(mutex);
if (connection)
{
if (connection->ref_count > 0)
{
connection->conn.disconnect();
connection->ref_count = 0;
}
connections.remove(connection);
}
}
void Pool::Entry::disconnect()
{
pool->removeConnection(data);
}
void Pool::Entry::forceConnected() const
{
if (data == nullptr)
throw Poco::RuntimeException("Tried to access NULL database connection.");
bool first = true;
while (!tryForceConnected())
{
if (first)
first = false;
else
sleepForSeconds(MYSQLXX_POOL_SLEEP_ON_CONNECT_FAIL);
pool->logger.debug("Entry: Reconnecting to MySQL server %s", pool->description);
data->conn.connect(
pool->db.c_str(),
pool->server.c_str(),
pool->user.c_str(),
pool->password.c_str(),
pool->port,
pool->socket.c_str(),
pool->ssl_ca.c_str(),
pool->ssl_cert.c_str(),
pool->ssl_key.c_str(),
pool->connect_timeout,
pool->rw_timeout,
pool->enable_local_infile,
pool->opt_reconnect);
}
}
bool Pool::Entry::tryForceConnected() const
{
auto * const mysql_driver = data->conn.getDriver();
const auto prev_connection_id = mysql_thread_id(mysql_driver);
pool->logger.trace("Entry(connection %lu): sending PING to check if it is alive.", prev_connection_id);
if (data->conn.ping()) /// Attempts to reestablish lost connection
{
const auto current_connection_id = mysql_thread_id(mysql_driver);
if (prev_connection_id != current_connection_id)
{
pool->logger.debug("Entry(connection %lu): Reconnected to MySQL server. Connection id changed: %lu -> %lu",
current_connection_id, prev_connection_id, current_connection_id);
}
pool->logger.trace("Entry(connection %lu): PING ok.", current_connection_id);
return true;
}
pool->logger.trace("Entry(connection %lu): PING failed.", prev_connection_id);
return false;
}
void Pool::initialize()
{
if (!initialized)
{
description = db + "@" + server + ":" + std::to_string(port) + " as user " + user;
for (unsigned i = 0; i < default_connections; ++i)
allocConnection();
initialized = true;
}
}
Pool::Connection * Pool::allocConnection(bool dont_throw_if_failed_first_time)
{
std::unique_ptr<Connection> conn_ptr{new Connection};
try
{
logger.debug("Connecting to %s", description);
conn_ptr->conn.connect(
db.c_str(),
server.c_str(),
user.c_str(),
password.c_str(),
port,
socket.c_str(),
ssl_ca.c_str(),
ssl_cert.c_str(),
ssl_key.c_str(),
connect_timeout,
rw_timeout,
enable_local_infile,
opt_reconnect);
}
catch (mysqlxx::ConnectionFailed & e)
{
logger.error(e.what());
if ((!was_successful && !dont_throw_if_failed_first_time)
|| e.errnum() == ER_ACCESS_DENIED_ERROR
|| e.errnum() == ER_DBACCESS_DENIED_ERROR
|| e.errnum() == ER_BAD_DB_ERROR)
{
throw;
}
else
{
return nullptr;
}
}
connections.push_back(conn_ptr.get());
was_successful = true;
return conn_ptr.release();
}
}