From cb845731b06dff50f9862c76215d232ed1f66600 Mon Sep 17 00:00:00 2001 From: kssenii Date: Wed, 31 Mar 2021 12:41:12 +0000 Subject: [PATCH] Add connection pools --- programs/odbc-bridge/ColumnInfoHandler.cpp | 6 +- .../odbc-bridge/IdentifierQuoteHandler.cpp | 7 +- programs/odbc-bridge/MainHandler.cpp | 9 +- programs/odbc-bridge/ODBCConnectionFactory.h | 83 +++++++++++++++++++ programs/odbc-bridge/SchemaAllowedHandler.cpp | 8 +- .../integration/test_odbc_interaction/test.py | 2 +- 6 files changed, 99 insertions(+), 16 deletions(-) create mode 100644 programs/odbc-bridge/ODBCConnectionFactory.h diff --git a/programs/odbc-bridge/ColumnInfoHandler.cpp b/programs/odbc-bridge/ColumnInfoHandler.cpp index cc184a5e703..16740b18357 100644 --- a/programs/odbc-bridge/ColumnInfoHandler.cpp +++ b/programs/odbc-bridge/ColumnInfoHandler.cpp @@ -17,8 +17,8 @@ #include #include "getIdentifierQuote.h" #include "validateODBCConnectionString.h" +#include "ODBCConnectionFactory.h" -#include #include #include @@ -105,8 +105,8 @@ void ODBCColumnsInfoHandler::handleRequest(HTTPServerRequest & request, HTTPServ { const bool external_table_functions_use_nulls = Poco::NumberParser::parseBool(params.get("external_table_functions_use_nulls", "false")); - nanodbc::connection connection(validateODBCConnectionString(connection_string)); - nanodbc::catalog catalog(connection); + auto connection = ODBCConnectionFactory::instance().get(validateODBCConnectionString(connection_string)); + nanodbc::catalog catalog(*connection); std::string catalog_name; /// In XDBC tables it is allowed to pass either database_name or schema_name in table definion, but not both of them. diff --git a/programs/odbc-bridge/IdentifierQuoteHandler.cpp b/programs/odbc-bridge/IdentifierQuoteHandler.cpp index a51470b5489..e0626504b31 100644 --- a/programs/odbc-bridge/IdentifierQuoteHandler.cpp +++ b/programs/odbc-bridge/IdentifierQuoteHandler.cpp @@ -14,7 +14,8 @@ #include #include "getIdentifierQuote.h" #include "validateODBCConnectionString.h" -#include +#include "ODBCConnectionFactory.h" + namespace DB { @@ -40,8 +41,8 @@ void IdentifierQuoteHandler::handleRequest(HTTPServerRequest & request, HTTPServ try { std::string connection_string = params.get("connection_string"); - nanodbc::connection connection(validateODBCConnectionString(connection_string)); - auto identifier = getIdentifierQuote(connection); + auto connection = ODBCConnectionFactory::instance().get(validateODBCConnectionString(connection_string)); + auto identifier = getIdentifierQuote(*connection); WriteBufferFromHTTPServerResponse out(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD, keep_alive_timeout); try diff --git a/programs/odbc-bridge/MainHandler.cpp b/programs/odbc-bridge/MainHandler.cpp index 5ea691e4301..74909458504 100644 --- a/programs/odbc-bridge/MainHandler.cpp +++ b/programs/odbc-bridge/MainHandler.cpp @@ -18,6 +18,7 @@ #include #include #include +#include "ODBCConnectionFactory.h" #include #include @@ -104,8 +105,8 @@ void ODBCHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse std::string connection_string = params.get("connection_string"); LOG_TRACE(log, "Connection string: '{}'", connection_string); - nanodbc::connection connection(connection_string); + auto connection = ODBCConnectionFactory::instance().get(validateODBCConnectionString(connection_string)); WriteBufferFromHTTPServerResponse out(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD, keep_alive_timeout); try @@ -128,12 +129,12 @@ void ODBCHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse auto quoting_style = IdentifierQuotingStyle::None; #if USE_ODBC - quoting_style = getQuotingStyle(connection); + quoting_style = getQuotingStyle(*connection); #endif auto & read_buf = request.getStream(); auto input_format = FormatFactory::instance().getInput(format, read_buf, *sample_block, context, max_block_size); auto input_stream = std::make_shared(input_format); - ODBCBlockOutputStream output_stream(connection, db_name, table_name, *sample_block, context, quoting_style); + ODBCBlockOutputStream output_stream(*connection, db_name, table_name, *sample_block, context, quoting_style); copyData(*input_stream, output_stream); writeStringBinary("Ok.", out); } @@ -143,7 +144,7 @@ void ODBCHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse LOG_TRACE(log, "Query: {}", query); BlockOutputStreamPtr writer = FormatFactory::instance().getOutputStreamParallelIfPossible(format, out, *sample_block, context); - ODBCBlockInputStream inp(connection, query, *sample_block, max_block_size); + ODBCBlockInputStream inp(*connection, query, *sample_block, max_block_size); copyData(inp, *writer); } } diff --git a/programs/odbc-bridge/ODBCConnectionFactory.h b/programs/odbc-bridge/ODBCConnectionFactory.h new file mode 100644 index 00000000000..80b55a5d1eb --- /dev/null +++ b/programs/odbc-bridge/ODBCConnectionFactory.h @@ -0,0 +1,83 @@ +#pragma once + +#include +#include +#include +#include +#include + + +namespace nanodbc +{ + +static constexpr inline auto ODBC_CONNECT_TIMEOUT = 100; + +using ConnectionPtr = std::shared_ptr; +using Pool = BorrowedObjectPool; +using PoolPtr = std::shared_ptr; + +class ConnectionHolder +{ + +public: + ConnectionHolder(const std::string & connection_string_, PoolPtr pool_) : connection_string(connection_string_), pool(pool_) {} + + ~ConnectionHolder() + { + pool->returnObject(std::move(connection)); + } + + nanodbc::connection & operator*() + { + if (!connection || !connection->connected()) + { + pool->borrowObject(connection, [&]() + { + return std::make_shared(connection_string, ODBC_CONNECT_TIMEOUT); + }); + } + + return *connection; + } + +private: + std::string connection_string; + PoolPtr pool; + ConnectionPtr connection; +}; + +} + + +namespace DB +{ + +static constexpr inline auto ODBC_DEFAULT_POOL_SIZE = 16; + +class ODBCConnectionFactory final : private boost::noncopyable +{ +public: + static ODBCConnectionFactory & instance() + { + static ODBCConnectionFactory ret; + return ret; + } + + nanodbc::ConnectionHolder get(const std::string & connection_string, size_t pool_size = ODBC_DEFAULT_POOL_SIZE) + { + std::lock_guard lock(mutex); + + if (!factory.count(connection_string)) + factory.emplace(std::make_pair(connection_string, std::make_shared(pool_size))); + + return nanodbc::ConnectionHolder(connection_string, factory[connection_string]); + } + +private: + /// [connection_string] -> [connection_pool] + using PoolFactory = std::unordered_map; + PoolFactory factory; + std::mutex mutex; +}; + +} diff --git a/programs/odbc-bridge/SchemaAllowedHandler.cpp b/programs/odbc-bridge/SchemaAllowedHandler.cpp index 0f645ee8710..bb95b80e225 100644 --- a/programs/odbc-bridge/SchemaAllowedHandler.cpp +++ b/programs/odbc-bridge/SchemaAllowedHandler.cpp @@ -9,8 +9,7 @@ #include #include #include "validateODBCConnectionString.h" - -#include +#include "ODBCConnectionFactory.h" #include #include @@ -49,9 +48,8 @@ void SchemaAllowedHandler::handleRequest(HTTPServerRequest & request, HTTPServer try { std::string connection_string = params.get("connection_string"); - nanodbc::connection connection(validateODBCConnectionString(connection_string)); - - bool result = isSchemaAllowed(connection); + auto connection = ODBCConnectionFactory::instance().get(validateODBCConnectionString(connection_string)); + bool result = isSchemaAllowed(*connection); WriteBufferFromHTTPServerResponse out(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD, keep_alive_timeout); try diff --git a/tests/integration/test_odbc_interaction/test.py b/tests/integration/test_odbc_interaction/test.py index 61b2e56ea9e..8b1b28a9a59 100644 --- a/tests/integration/test_odbc_interaction/test.py +++ b/tests/integration/test_odbc_interaction/test.py @@ -269,7 +269,7 @@ def test_sqlite_odbc_cached_dictionary(started_cluster): node1.exec_in_container(["bash", "-c", "chmod a+rw /tmp"], privileged=True, user='root') node1.exec_in_container(["bash", "-c", "chmod a+rw {}".format(sqlite_db)], privileged=True, user='root') - node1.query("insert into table function odbc('DSN={};', '', 't3') values (200, 2, 7)".format( + node1.query("insert into table function odbc('DSN={};ReadOnly=0', '', 't3') values (200, 2, 7)".format( node1.odbc_drivers["SQLite3"]["DSN"])) assert node1.query("select dictGetUInt8('sqlite3_odbc_cached', 'Z', toUInt64(200))") == "7\n" # new value