Better way

This commit is contained in:
kssenii 2021-06-07 18:09:16 +00:00
parent ba56465a99
commit bd39c9fdd1
10 changed files with 73 additions and 50 deletions

View File

@ -105,19 +105,19 @@ void ODBCColumnsInfoHandler::handleRequest(HTTPServerRequest & request, HTTPServ
{
const bool external_table_functions_use_nulls = Poco::NumberParser::parseBool(params.get("external_table_functions_use_nulls", "false"));
auto connection = ODBCConnectionFactory::instance().get(
auto connection_holder = ODBCConnectionFactory::instance().get(
validateODBCConnectionString(connection_string),
getContext()->getSettingsRef().odbc_bridge_connection_pool_size);
nanodbc::catalog catalog(connection->get());
std::string catalog_name;
/// In XDBC tables it is allowed to pass either database_name or schema_name in table definion, but not both of them.
/// They both are passed as 'schema' parameter in request URL, so it is not clear whether it is database_name or schema_name passed.
/// If it is schema_name then we know that database is added in odbc.ini. But if we have database_name as 'schema',
/// it is not guaranteed. For nanodbc database_name must be either in odbc.ini or passed as catalog_name.
auto get_columns = [&]()
auto get_columns = [&](nanodbc::connection & connection)
{
nanodbc::catalog catalog(connection);
std::string catalog_name;
nanodbc::catalog::tables tables = catalog.find_tables(table_name, /* type = */ "", /* schema = */ "", /* catalog = */ schema_name);
if (tables.next())
{
@ -137,7 +137,9 @@ void ODBCColumnsInfoHandler::handleRequest(HTTPServerRequest & request, HTTPServ
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Table {} not found", schema_name.empty() ? table_name : schema_name + '.' + table_name);
};
nanodbc::catalog::columns columns_definition = get_columns();
nanodbc::catalog::columns columns_definition = execute<nanodbc::catalog::columns>(
std::move(connection_holder),
[&](nanodbc::connection & connection) { return get_columns(connection); });
NamesAndTypesList columns;
while (columns_definition.next())

View File

@ -46,7 +46,7 @@ void IdentifierQuoteHandler::handleRequest(HTTPServerRequest & request, HTTPServ
validateODBCConnectionString(connection_string),
getContext()->getSettingsRef().odbc_bridge_connection_pool_size);
auto identifier = getIdentifierQuote(connection->get());
auto identifier = getIdentifierQuote(std::move(connection));
WriteBufferFromHTTPServerResponse out(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD, keep_alive_timeout);
try

View File

@ -108,7 +108,7 @@ void ODBCHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse
try
{
auto connection = ODBCConnectionFactory::instance().get(
auto connection_handler = ODBCConnectionFactory::instance().get(
validateODBCConnectionString(connection_string),
getContext()->getSettingsRef().odbc_bridge_connection_pool_size);
@ -130,12 +130,12 @@ void ODBCHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse
auto quoting_style = IdentifierQuotingStyle::None;
#if USE_ODBC
quoting_style = getQuotingStyle(connection->get());
quoting_style = getQuotingStyle(connection_handler);
#endif
auto & read_buf = request.getStream();
auto input_format = FormatFactory::instance().getInput(format, read_buf, *sample_block, getContext(), max_block_size);
auto input_stream = std::make_shared<InputStreamFromInputFormat>(input_format);
ODBCBlockOutputStream output_stream(std::move(connection), db_name, table_name, *sample_block, getContext(), quoting_style);
ODBCBlockOutputStream output_stream(std::move(connection_handler), db_name, table_name, *sample_block, getContext(), quoting_style);
copyData(*input_stream, output_stream);
writeStringBinary("Ok.", out);
}
@ -145,7 +145,7 @@ void ODBCHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse
LOG_TRACE(log, "Query: {}", query);
BlockOutputStreamPtr writer = FormatFactory::instance().getOutputStreamParallelIfPossible(format, out, *sample_block, getContext());
ODBCBlockInputStream inp(std::move(connection), query, *sample_block, max_block_size);
ODBCBlockInputStream inp(std::move(connection_handler), query, *sample_block, max_block_size);
copyData(inp, *writer);
}
}

View File

@ -21,13 +21,14 @@ namespace ErrorCodes
ODBCBlockInputStream::ODBCBlockInputStream(
nanodbc::ConnectionHolderPtr connection, const std::string & query_str, const Block & sample_block, const UInt64 max_block_size_)
nanodbc::ConnectionHolderPtr connection_holder, const std::string & query_str, const Block & sample_block, const UInt64 max_block_size_)
: log(&Poco::Logger::get("ODBCBlockInputStream"))
, max_block_size{max_block_size_}
, query(query_str)
{
description.init(sample_block);
result = execute(connection->get(), NANODBC_TEXT(query));
result = execute<nanodbc::result>(connection_holder,
[&](nanodbc::connection & connection) { return execute(connection, query); });
}

View File

@ -40,14 +40,14 @@ namespace
}
}
ODBCBlockOutputStream::ODBCBlockOutputStream(nanodbc::ConnectionHolderPtr connection_,
ODBCBlockOutputStream::ODBCBlockOutputStream(nanodbc::ConnectionHolderPtr connection_holder_,
const std::string & remote_database_name_,
const std::string & remote_table_name_,
const Block & sample_block_,
ContextPtr local_context_,
IdentifierQuotingStyle quoting_)
: log(&Poco::Logger::get("ODBCBlockOutputStream"))
, connection(std::move(connection_))
, connection_holder(std::move(connection_holder_))
, db_name(remote_database_name_)
, table_name(remote_table_name_)
, sample_block(sample_block_)
@ -69,7 +69,8 @@ void ODBCBlockOutputStream::write(const Block & block)
writer->write(block);
std::string query = getInsertQuery(db_name, table_name, block.getColumnsWithTypeAndName(), quoting) + values_buf.str();
execute(connection->get(), query);
execute<void>(connection_holder,
[&](nanodbc::connection & connection) { execute(connection, query); });
}
}

View File

@ -29,7 +29,7 @@ public:
private:
Poco::Logger * log;
nanodbc::ConnectionHolderPtr connection;
nanodbc::ConnectionHolderPtr connection_holder;
std::string db_name;
std::string table_name;
Block sample_block;

View File

@ -21,10 +21,20 @@ using ConnectionPtr = std::unique_ptr<nanodbc::connection>;
using Pool = BorrowedObjectPool<ConnectionPtr>;
using PoolPtr = std::shared_ptr<Pool>;
static constexpr inline auto ODBC_CONNECT_TIMEOUT = 100;
class ConnectionHolder
{
public:
ConnectionHolder(PoolPtr pool_, ConnectionPtr connection_) : pool(pool_), connection(std::move(connection_)) {}
ConnectionHolder(PoolPtr pool_,
ConnectionPtr connection_,
const String & connection_string_)
: pool(pool_)
, connection(std::move(connection_))
, connection_string(connection_string_)
{
}
ConnectionHolder(const ConnectionHolder & other) = delete;
@ -39,12 +49,19 @@ public:
return *connection;
}
void updateConnection()
{
connection = std::make_unique<nanodbc::connection>(connection_string, ODBC_CONNECT_TIMEOUT);
}
private:
PoolPtr pool;
ConnectionPtr connection;
const String & connection_string;
};
using ConnectionHolderPtr = std::unique_ptr<ConnectionHolder>;
using ConnectionHolderPtr = std::shared_ptr<ConnectionHolder>;
}
@ -53,7 +70,26 @@ namespace DB
static constexpr inline auto ODBC_CONNECT_TIMEOUT = 100;
static constexpr inline auto ODBC_POOL_WAIT_TIMEOUT = 10000;
static constexpr auto ODBC_CHECK_CONNECTION_QUERY = "SELECT 1";
template <typename T>
T execute(nanodbc::ConnectionHolderPtr connection_holder, std::function<T(nanodbc::connection &)> query_func)
{
try
{
return query_func(connection_holder->get());
}
catch (const nanodbc::database_error & e)
{
/// SQLState, connection related errors start with 08S0.
if (e.state().starts_with("08S0"))
{
connection_holder->updateConnection();
return query_func(connection_holder->get());
}
throw;
}
}
class ODBCConnectionFactory final : private boost::noncopyable
{
@ -64,22 +100,6 @@ public:
return ret;
}
/// this check is performed only on the connection which was already successfully open before.
static bool needReconnect(nanodbc::connection & connection)
{
try
{
/// just_execute - execution without preparing any result object.
just_execute(connection, ODBC_CHECK_CONNECTION_QUERY);
}
catch (const nanodbc::database_error &)
{
return true;
}
return false;
}
nanodbc::ConnectionHolderPtr get(const std::string & connection_string, size_t pool_size)
{
std::lock_guard lock(mutex);
@ -97,10 +117,8 @@ public:
try
{
if (!connection || needReconnect(*connection))
{
if (!connection)
connection = std::make_unique<nanodbc::connection>(connection_string, ODBC_CONNECT_TIMEOUT);
}
}
catch (...)
{
@ -108,7 +126,7 @@ public:
throw;
}
return std::make_unique<nanodbc::ConnectionHolder>(factory[connection_string], std::move(connection));
return std::make_unique<nanodbc::ConnectionHolder>(factory[connection_string], std::move(connection), connection_string);
}
private:

View File

@ -18,9 +18,10 @@ namespace DB
{
namespace
{
bool isSchemaAllowed(nanodbc::connection & connection)
bool isSchemaAllowed(nanodbc::ConnectionHolderPtr connection_holder)
{
uint32_t result = connection.get_info<uint32_t>(SQL_SCHEMA_USAGE);
uint32_t result = execute<uint32_t>(connection_holder,
[&](nanodbc::connection & connection) { return connection.get_info<uint32_t>(SQL_SCHEMA_USAGE); });
return result != 0;
}
}
@ -53,7 +54,7 @@ void SchemaAllowedHandler::handleRequest(HTTPServerRequest & request, HTTPServer
validateODBCConnectionString(connection_string),
getContext()->getSettingsRef().odbc_bridge_connection_pool_size);
bool result = isSchemaAllowed(connection->get());
bool result = isSchemaAllowed(std::move(connection));
WriteBufferFromHTTPServerResponse out(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD, keep_alive_timeout);
try

View File

@ -16,12 +16,13 @@ namespace ErrorCodes
}
std::string getIdentifierQuote(nanodbc::connection & connection)
std::string getIdentifierQuote(nanodbc::ConnectionHolderPtr connection_holder)
{
std::string quote;
try
{
quote = connection.get_info<std::string>(SQL_IDENTIFIER_QUOTE_CHAR);
quote = execute<std::string>(connection_holder,
[&](nanodbc::connection & connection) { return connection.get_info<std::string>(SQL_IDENTIFIER_QUOTE_CHAR); });
}
catch (...)
{
@ -33,7 +34,7 @@ std::string getIdentifierQuote(nanodbc::connection & connection)
}
IdentifierQuotingStyle getQuotingStyle(nanodbc::connection & connection)
IdentifierQuotingStyle getQuotingStyle(nanodbc::ConnectionHolderPtr connection)
{
auto identifier_quote = getIdentifierQuote(connection);
if (identifier_quote.length() == 0)

View File

@ -6,15 +6,14 @@
#include <Poco/Logger.h>
#include <Poco/Net/HTTPRequestHandler.h>
#include <Parsers/IdentifierQuotingStyle.h>
#include <nanodbc/nanodbc.h>
#include "ODBCConnectionFactory.h"
namespace DB
{
std::string getIdentifierQuote(nanodbc::connection & connection);
IdentifierQuotingStyle getQuotingStyle(nanodbc::connection & connection);
std::string getIdentifierQuote(nanodbc::ConnectionHolderPtr connection_holder);
IdentifierQuotingStyle getQuotingStyle(nanodbc::ConnectionHolderPtr connection);
}