#include #include #include #include #include #include #include "PostgreSQLHandler.h" #include #include #include #include #include #include #include #if USE_SSL # include # include #endif namespace DB { namespace ErrorCodes { extern const int SYNTAX_ERROR; } PostgreSQLHandler::PostgreSQLHandler( const Poco::Net::StreamSocket & socket_, IServer & server_, TCPServer & tcp_server_, bool ssl_enabled_, Int32 connection_id_, std::vector> & auth_methods_, const ProfileEvents::Event & read_event_, const ProfileEvents::Event & write_event_) : Poco::Net::TCPServerConnection(socket_) , server(server_) , tcp_server(tcp_server_) , ssl_enabled(ssl_enabled_) , connection_id(connection_id_) , read_event(read_event_) , write_event(write_event_) , authentication_manager(auth_methods_) { changeIO(socket()); } void PostgreSQLHandler::changeIO(Poco::Net::StreamSocket & socket) { in = std::make_shared(socket, read_event); out = std::make_shared(socket, write_event); message_transport = std::make_shared(in.get(), out.get()); } void PostgreSQLHandler::run() { setThreadName("PostgresHandler"); ThreadStatus thread_status; session = std::make_unique(server.context(), ClientInfo::Interface::POSTGRESQL); SCOPE_EXIT({ session.reset(); }); session->setClientConnectionId(connection_id); try { if (!startup()) return; while (tcp_server.isOpen()) { message_transport->send(PostgreSQLProtocol::Messaging::ReadyForQuery(), true); constexpr size_t connection_check_timeout = 1; // 1 second while (!in->poll(1000000 * connection_check_timeout)) if (!tcp_server.isOpen()) return; PostgreSQLProtocol::Messaging::FrontMessageType message_type = message_transport->receiveMessageType(); if (!tcp_server.isOpen()) return; switch (message_type) { case PostgreSQLProtocol::Messaging::FrontMessageType::QUERY: processQuery(); break; case PostgreSQLProtocol::Messaging::FrontMessageType::TERMINATE: LOG_DEBUG(log, "Client closed the connection"); return; case PostgreSQLProtocol::Messaging::FrontMessageType::PARSE: case PostgreSQLProtocol::Messaging::FrontMessageType::BIND: case PostgreSQLProtocol::Messaging::FrontMessageType::DESCRIBE: case PostgreSQLProtocol::Messaging::FrontMessageType::SYNC: case PostgreSQLProtocol::Messaging::FrontMessageType::FLUSH: case PostgreSQLProtocol::Messaging::FrontMessageType::CLOSE: message_transport->send( PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse( PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse::ERROR, "0A000", "ClickHouse doesn't support extended query mechanism"), true); LOG_ERROR(log, "Client tried to access via extended query protocol"); message_transport->dropMessage(); break; default: message_transport->send( PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse( PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse::ERROR, "0A000", "Command is not supported"), true); LOG_ERROR(log, "Command is not supported. Command code {:d}", static_cast(message_type)); message_transport->dropMessage(); } } } catch (const Poco::Exception &exc) { log->log(exc); } } bool PostgreSQLHandler::startup() { Int32 payload_size; Int32 info; establishSecureConnection(payload_size, info); if (static_cast(info) == PostgreSQLProtocol::Messaging::FrontMessageType::CANCEL_REQUEST) { LOG_DEBUG(log, "Client issued request canceling"); cancelRequest(); return false; } std::unique_ptr start_up_msg = receiveStartupMessage(payload_size); const auto & user_name = start_up_msg->user; authentication_manager.authenticate(user_name, *session, *message_transport, socket().peerAddress()); try { session->makeSessionContext(); session->sessionContext()->setDefaultFormat("PostgreSQLWire"); if (!start_up_msg->database.empty()) session->sessionContext()->setCurrentDatabase(start_up_msg->database); } catch (const Exception & exc) { message_transport->send( PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse( PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse::ERROR, "XX000", exc.message()), true); throw; } sendParameterStatusData(*start_up_msg); message_transport->send( PostgreSQLProtocol::Messaging::BackendKeyData(connection_id, secret_key), true); LOG_DEBUG(log, "Successfully finished Startup stage"); return true; } void PostgreSQLHandler::establishSecureConnection(Int32 & payload_size, Int32 & info) { bool was_encryption_req = true; readBinaryBigEndian(payload_size, *in); readBinaryBigEndian(info, *in); switch (static_cast(info)) { case PostgreSQLProtocol::Messaging::FrontMessageType::SSL_REQUEST: LOG_DEBUG(log, "Client requested SSL"); if (ssl_enabled) makeSecureConnectionSSL(); else message_transport->send('N', true); break; case PostgreSQLProtocol::Messaging::FrontMessageType::GSSENC_REQUEST: LOG_DEBUG(log, "Client requested GSSENC"); message_transport->send('N', true); break; default: was_encryption_req = false; } if (was_encryption_req) { readBinaryBigEndian(payload_size, *in); readBinaryBigEndian(info, *in); } } #if USE_SSL void PostgreSQLHandler::makeSecureConnectionSSL() { message_transport->send('S'); ss = std::make_shared( Poco::Net::SecureStreamSocket::attach(socket(), Poco::Net::SSLManager::instance().defaultServerContext())); changeIO(*ss); } #else void PostgreSQLHandler::makeSecureConnectionSSL() {} #endif void PostgreSQLHandler::sendParameterStatusData(PostgreSQLProtocol::Messaging::StartupMessage & start_up_message) { std::unordered_map & parameters = start_up_message.parameters; if (parameters.find("application_name") != parameters.end()) message_transport->send(PostgreSQLProtocol::Messaging::ParameterStatus("application_name", parameters["application_name"])); if (parameters.find("client_encoding") != parameters.end()) message_transport->send(PostgreSQLProtocol::Messaging::ParameterStatus("client_encoding", parameters["client_encoding"])); else message_transport->send(PostgreSQLProtocol::Messaging::ParameterStatus("client_encoding", "UTF8")); message_transport->send(PostgreSQLProtocol::Messaging::ParameterStatus("server_version", VERSION_STRING)); message_transport->send(PostgreSQLProtocol::Messaging::ParameterStatus("server_encoding", "UTF8")); message_transport->send(PostgreSQLProtocol::Messaging::ParameterStatus("DateStyle", "ISO")); message_transport->flush(); } void PostgreSQLHandler::cancelRequest() { std::unique_ptr msg = message_transport->receiveWithPayloadSize(8); String query = fmt::format("KILL QUERY WHERE query_id = 'postgres:{:d}:{:d}'", msg->process_id, msg->secret_key); ReadBufferFromString replacement(query); auto query_context = session->makeQueryContext(); query_context->setCurrentQueryId(""); executeQuery(replacement, *out, true, query_context, {}); } inline std::unique_ptr PostgreSQLHandler::receiveStartupMessage(int payload_size) { std::unique_ptr message; try { message = message_transport->receiveWithPayloadSize(payload_size - 8); } catch (const Exception &) { message_transport->send( PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse( PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse::ERROR, "08P01", "Can't correctly handle Startup message"), true); throw; } LOG_DEBUG(log, "Successfully received Startup message"); return message; } void PostgreSQLHandler::processQuery() { try { std::unique_ptr query = message_transport->receive(); if (isEmptyQuery(query->query)) { message_transport->send(PostgreSQLProtocol::Messaging::EmptyQueryResponse()); return; } bool psycopg2_cond = query->query == "BEGIN" || query->query == "COMMIT"; // psycopg2 starts and ends queries with BEGIN/COMMIT commands bool jdbc_cond = query->query.find("SET extra_float_digits") != String::npos || query->query.find("SET application_name") != String::npos; // jdbc starts with setting this parameter if (psycopg2_cond || jdbc_cond) { message_transport->send( PostgreSQLProtocol::Messaging::CommandComplete( PostgreSQLProtocol::Messaging::CommandComplete::classifyQuery(query->query), 0)); return; } const auto & settings = session->sessionContext()->getSettingsRef(); std::vector queries; auto parse_res = splitMultipartQuery(query->query, queries, settings.max_query_size, settings.max_parser_depth, settings.max_parser_backtracks, settings.allow_settings_after_format_in_insert); if (!parse_res.second) throw Exception(ErrorCodes::SYNTAX_ERROR, "Cannot parse and execute the following part of query: {}", String(parse_res.first)); pcg64_fast gen{randomSeed()}; std::uniform_int_distribution dis(0, INT32_MAX); for (const auto & spl_query : queries) { secret_key = dis(gen); auto query_context = session->makeQueryContext(); query_context->setCurrentQueryId(fmt::format("postgres:{:d}:{:d}", connection_id, secret_key)); CurrentThread::QueryScope query_scope{query_context}; ReadBufferFromString read_buf(spl_query); executeQuery(read_buf, *out, false, query_context, {}); PostgreSQLProtocol::Messaging::CommandComplete::Command command = PostgreSQLProtocol::Messaging::CommandComplete::classifyQuery(spl_query); message_transport->send(PostgreSQLProtocol::Messaging::CommandComplete(command, 0), true); } } catch (const Exception & e) { message_transport->send( PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse( PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse::ERROR, "2F000", "Query execution failed.\n" + e.displayText()), true); throw; } } bool PostgreSQLHandler::isEmptyQuery(const String & query) { if (query.empty()) return true; /// golang driver pgx sends ";" if (query == ";") return true; Poco::RegularExpression regex(R"(\A\s*\z)"); return regex.match(query); } }