From fb0b445ee0e9f11e0cc134662284cefa3ed94518 Mon Sep 17 00:00:00 2001 From: kssenii Date: Tue, 17 Aug 2021 22:59:51 +0300 Subject: [PATCH] Lets add LocalConnection (a start) --- programs/client/Client.cpp | 379 +------------------- programs/client/Client.h | 37 +- programs/local/LocalServer.cpp | 23 +- programs/local/LocalServer.h | 7 + src/Client/ClientBase.cpp | 376 ++++++++++++++++++- src/Client/ClientBase.h | 36 ++ src/Client/Connection.cpp | 13 +- src/Client/Connection.h | 135 +++---- src/Client/IServerConnection.h | 124 +++++++ src/Client/LocalConnection.cpp | 215 +++++++++++ src/Client/LocalConnection.h | 160 +++++++++ src/Client/Suggest.cpp | 2 +- src/DataStreams/RemoteBlockOutputStream.cpp | 8 +- 13 files changed, 992 insertions(+), 523 deletions(-) create mode 100644 src/Client/IServerConnection.h create mode 100644 src/Client/LocalConnection.cpp create mode 100644 src/Client/LocalConnection.h diff --git a/programs/client/Client.cpp b/programs/client/Client.cpp index af1b534987f..c883885711a 100644 --- a/programs/client/Client.cpp +++ b/programs/client/Client.cpp @@ -51,7 +51,6 @@ #include #include -#include #include #include @@ -65,7 +64,6 @@ #include #include -#include #include #include @@ -289,7 +287,7 @@ static bool queryHasWithClause(const IAST * ast) std::vector Client::loadWarningMessages() { std::vector messages; - connection->sendQuery(connection_parameters.timeouts, "SELECT message FROM system.warnings", "" /* query_id */, QueryProcessingStage::Complete); + connection->sendQuery(connection_parameters.timeouts, "SELECT message FROM system.warnings", "" /* query_id */, QueryProcessingStage::Complete, nullptr, nullptr, false); while (true) { Packet packet = connection->receivePacket(); @@ -930,21 +928,6 @@ bool Client::processWithFuzzing(const String & full_query) } -/// Convert external tables to ExternalTableData and send them using the connection. -void Client::sendExternalTables(ASTPtr parsed_query) -{ - const auto * select = parsed_query->as(); - if (!select && !external_tables.empty()) - throw Exception("External tables could be sent only with select query", ErrorCodes::BAD_ARGUMENTS); - - std::vector data; - for (auto & table : external_tables) - data.emplace_back(table.getData(global_context)); - - connection->sendExternalTablesData(data); -} - - void Client::processInsertQuery(const String & query_to_execute, ASTPtr parsed_query) { /// Process the query that requires transferring data blocks to the server. @@ -976,61 +959,6 @@ void Client::processInsertQuery(const String & query_to_execute, ASTPtr parsed_q } -void Client::processOrdinaryQuery(const String & query_to_execute, ASTPtr parsed_query) -{ - /// Rewrite query only when we have query parameters. - /// Note that if query is rewritten, comments in query are lost. - /// But the user often wants to see comments in server logs, query log, processlist, etc. - auto query = query_to_execute; - if (!query_parameters.empty()) - { - /// Replace ASTQueryParameter with ASTLiteral for prepared statements. - ReplaceQueryParameterVisitor visitor(query_parameters); - visitor.visit(parsed_query); - - /// Get new query after substitutions. Note that it cannot be done for INSERT query with embedded data. - query = serializeAST(*parsed_query); - } - - int retries_left = 10; - for (;;) - { - assert(retries_left > 0); - - try - { - connection->sendQuery( - connection_parameters.timeouts, - query, - global_context->getCurrentQueryId(), - query_processing_stage, - &global_context->getSettingsRef(), - &global_context->getClientInfo(), - true); - - sendExternalTables(parsed_query); - receiveResult(parsed_query); - - break; - } - catch (const Exception & e) - { - /// Retry when the server said "Client should retry" and no rows - /// has been received yet. - if (processed_rows == 0 && e.code() == ErrorCodes::DEADLOCK_AVOIDED && --retries_left) - { - std::cerr << "Got a transient error from the server, will" - << " retry (" << retries_left << " retries left)"; - } - else - { - throw; - } - } - } -} - - void Client::executeSingleQuery(const String & query_to_execute, ASTPtr parsed_query) { client_exception.reset(); @@ -1196,7 +1124,7 @@ void Client::sendDataFrom(ReadBuffer & buf, Block & sample, const ColumnsDescrip receiveLogs(parsed_query); /// Check if server send Exception packet - auto packet_type = connection->checkPacket(); + auto packet_type = connection->checkPacket(/* timeout_milliseconds */0); if (packet_type && *packet_type == Protocol::Server::Exception) { /* @@ -1209,137 +1137,12 @@ void Client::sendDataFrom(ReadBuffer & buf, Block & sample, const ColumnsDescrip if (block) { - connection->sendData(block); + connection->sendData(block, /* name */"", /* scalar */false); processed_rows += block.rows(); } } - connection->sendData({}); -} - - -/// Receives and processes packets coming from server. -/// Also checks if query execution should be cancelled. -void Client::receiveResult(ASTPtr parsed_query) -{ - InterruptListener interrupt_listener; - bool cancelled = false; - - // TODO: get the poll_interval from commandline. - const auto receive_timeout = connection_parameters.timeouts.receive_timeout; - constexpr size_t default_poll_interval = 1000000; /// in microseconds - constexpr size_t min_poll_interval = 5000; /// in microseconds - const size_t poll_interval - = std::max(min_poll_interval, std::min(receive_timeout.totalMicroseconds(), default_poll_interval)); - - while (true) - { - Stopwatch receive_watch(CLOCK_MONOTONIC_COARSE); - - while (true) - { - /// Has the Ctrl+C been pressed and thus the query should be cancelled? - /// If this is the case, inform the server about it and receive the remaining packets - /// to avoid losing sync. - if (!cancelled) - { - auto cancel_query = [&] { - connection->sendCancel(); - cancelled = true; - if (is_interactive) - { - progress_indication.clearProgressOutput(); - std::cout << "Cancelling query." << std::endl; - } - - /// Pressing Ctrl+C twice results in shut down. - interrupt_listener.unblock(); - }; - - if (interrupt_listener.check()) - { - cancel_query(); - } - else - { - double elapsed = receive_watch.elapsedSeconds(); - if (elapsed > receive_timeout.totalSeconds()) - { - std::cout << "Timeout exceeded while receiving data from server." - << " Waited for " << static_cast(elapsed) << " seconds," - << " timeout is " << receive_timeout.totalSeconds() << " seconds." << std::endl; - - cancel_query(); - } - } - } - - /// Poll for changes after a cancellation check, otherwise it never reached - /// because of progress updates from server. - if (connection->poll(poll_interval)) - break; - } - - if (!receiveAndProcessPacket(parsed_query, cancelled)) - break; - } - - if (cancelled && is_interactive) - std::cout << "Query was cancelled." << std::endl; -} - - -/// Receive a part of the result, or progress info or an exception and process it. -/// Returns true if one should continue receiving packets. -/// Output of result is suppressed if query was cancelled. -bool Client::receiveAndProcessPacket(ASTPtr parsed_query, bool cancelled) -{ - Packet packet = connection->receivePacket(); - - switch (packet.type) - { - case Protocol::Server::PartUUIDs: - return true; - - case Protocol::Server::Data: - if (!cancelled) - onData(packet.block, parsed_query); - return true; - - case Protocol::Server::Progress: - onProgress(packet.progress); - return true; - - case Protocol::Server::ProfileInfo: - onProfileInfo(packet.profile_info); - return true; - - case Protocol::Server::Totals: - if (!cancelled) - onTotals(packet.block, parsed_query); - return true; - - case Protocol::Server::Extremes: - if (!cancelled) - onExtremes(packet.block, parsed_query); - return true; - - case Protocol::Server::Exception: - onReceiveExceptionFromServer(std::move(packet.exception)); - return false; - - case Protocol::Server::Log: - onLogData(packet.block); - return true; - - case Protocol::Server::EndOfStream: - onEndOfStream(); - return false; - - default: - throw Exception( - ErrorCodes::UNKNOWN_PACKET_FROM_SERVER, "Unknown packet {} from server {}", packet.type, connection->getDescription()); - } + connection->sendData({}, "", false); } @@ -1412,192 +1215,22 @@ bool Client::receiveEndOfQuery() /// Process Log packets, used when inserting data by blocks void Client::receiveLogs(ASTPtr parsed_query) { - auto packet_type = connection->checkPacket(); + auto packet_type = connection->checkPacket(0); while (packet_type && *packet_type == Protocol::Server::Log) { receiveAndProcessPacket(parsed_query, false); - packet_type = connection->checkPacket(); + packet_type = connection->checkPacket(/* timeout_milliseconds */0); } } -void Client::initBlockOutputStream(const Block & block, ASTPtr parsed_query) -{ - if (!block_out_stream) - { - /// Ignore all results when fuzzing as they can be huge. - if (query_fuzzer_runs) - { - block_out_stream = std::make_shared(block); - return; - } - - WriteBuffer * out_buf = nullptr; - String pager = config().getString("pager", ""); - if (!pager.empty()) - { - signal(SIGPIPE, SIG_IGN); - pager_cmd = ShellCommand::execute(pager, true); - out_buf = &pager_cmd->in; - } - else - { - out_buf = &std_out; - } - - String current_format = format; - - /// The query can specify output format or output file. - /// FIXME: try to prettify this cast using `as<>()` - if (const auto * query_with_output = dynamic_cast(parsed_query.get())) - { - if (query_with_output->out_file) - { - const auto & out_file_node = query_with_output->out_file->as(); - const auto & out_file = out_file_node.value.safeGet(); - - out_file_buf = wrapWriteBufferWithCompressionMethod( - std::make_unique(out_file, DBMS_DEFAULT_BUFFER_SIZE, O_WRONLY | O_EXCL | O_CREAT), - chooseCompressionMethod(out_file, ""), - /* compression level = */ 3 - ); - - // We are writing to file, so default format is the same as in non-interactive mode. - if (is_interactive && is_default_format) - current_format = "TabSeparated"; - } - if (query_with_output->format != nullptr) - { - if (has_vertical_output_suffix) - throw Exception("Output format already specified", ErrorCodes::CLIENT_OUTPUT_FORMAT_SPECIFIED); - const auto & id = query_with_output->format->as(); - current_format = id.name(); - } - } - - if (has_vertical_output_suffix) - current_format = "Vertical"; - - /// It is not clear how to write progress with parallel formatting. It may increase code complexity significantly. - if (!need_render_progress) - block_out_stream = global_context->getOutputStreamParallelIfPossible(current_format, out_file_buf ? *out_file_buf : *out_buf, block); - else - block_out_stream = global_context->getOutputStream(current_format, out_file_buf ? *out_file_buf : *out_buf, block); - - block_out_stream->writePrefix(); - } -} - - -void Client::initLogsOutputStream() -{ - if (!logs_out_stream) - { - WriteBuffer * wb = out_logs_buf.get(); - - if (!out_logs_buf) - { - if (server_logs_file.empty()) - { - /// Use stderr by default - out_logs_buf = std::make_unique(STDERR_FILENO); - wb = out_logs_buf.get(); - } - else if (server_logs_file == "-") - { - /// Use stdout if --server_logs_file=- specified - wb = &std_out; - } - else - { - out_logs_buf - = std::make_unique(server_logs_file, DBMS_DEFAULT_BUFFER_SIZE, O_WRONLY | O_APPEND | O_CREAT); - wb = out_logs_buf.get(); - } - } - - logs_out_stream = std::make_shared(*wb, stdout_is_a_tty); - logs_out_stream->writePrefix(); - } -} - - -void Client::onData(Block & block, ASTPtr parsed_query) -{ - if (!block) - return; - - processed_rows += block.rows(); - - /// Even if all blocks are empty, we still need to initialize the output stream to write empty resultset. - initBlockOutputStream(block, parsed_query); - - /// The header block containing zero rows was used to initialize - /// block_out_stream, do not output it. - /// Also do not output too much data if we're fuzzing. - if (block.rows() == 0 || (query_fuzzer_runs != 0 && processed_rows >= 100)) - return; - - if (need_render_progress) - progress_indication.clearProgressOutput(); - - block_out_stream->write(block); - written_first_block = true; - - /// Received data block is immediately displayed to the user. - block_out_stream->flush(); - - /// Restore progress bar after data block. - if (need_render_progress) - progress_indication.writeProgress(); -} - - -void Client::onLogData(Block & block) -{ - initLogsOutputStream(); - progress_indication.clearProgressOutput(); - logs_out_stream->write(block); - logs_out_stream->flush(); -} - - -void Client::onTotals(Block & block, ASTPtr parsed_query) -{ - initBlockOutputStream(block, parsed_query); - block_out_stream->setTotals(block); -} - - -void Client::onExtremes(Block & block, ASTPtr parsed_query) -{ - initBlockOutputStream(block, parsed_query); - block_out_stream->setExtremes(block); -} - - void Client::writeFinalProgress() { progress_indication.writeFinalProgress(); } -void Client::onReceiveExceptionFromServer(std::unique_ptr && e) -{ - have_error = true; - server_exception = std::move(e); - resetOutput(); -} - - -void Client::onProfileInfo(const BlockStreamProfileInfo & profile_info) -{ - if (profile_info.hasAppliedLimit() && block_out_stream) - block_out_stream->setRowsBeforeLimit(profile_info.getRowsBeforeLimit()); -} - - void Client::readArguments(int argc, char ** argv, Arguments & common_arguments, std::vector & external_tables_arguments) { /** We allow different groups of arguments: diff --git a/programs/client/Client.h b/programs/client/Client.h index d94a89eedf6..7cda5ddc3a6 100644 --- a/programs/client/Client.h +++ b/programs/client/Client.h @@ -1,7 +1,6 @@ #pragma once #include -#include namespace DB @@ -42,21 +41,9 @@ protected: void processOptions(const OptionsDescription & options_description, const CommandLineOptions & options, const std::vector & external_tables_arguments) override; - - void processConfig() override; +void processConfig() override; private: - std::unique_ptr connection; /// Connection to DB. - ConnectionParameters connection_parameters; - - /// The last exception that was received from the server. Is used for the - /// return code in batch mode. - std::unique_ptr server_exception; - /// Likewise, the last exception that occurred on the client. - std::unique_ptr client_exception; - - String format; /// Query results output format. - bool is_default_format = true; /// false, if format is set in the config or command line. size_t format_max_block_size = 0; /// Max block size for console output. String insert_format; /// Format of INSERT data that is read from stdin in batch mode. size_t insert_format_max_block_size = 0; /// Max block size when reading INSERT data. @@ -65,42 +52,22 @@ private: UInt64 server_revision = 0; String server_version; - /// External tables info. - std::list external_tables; - - /// Dictionary with query parameters for prepared statements. - NameToNameMap query_parameters; - QueryProcessingStage::Enum query_processing_stage; String current_profile; - void connect(); + void connect() override; void printChangedSettings() const; - void sendExternalTables(ASTPtr parsed_query); void processInsertQuery(const String & query_to_execute, ASTPtr parsed_query); - void processOrdinaryQuery(const String & query_to_execute, ASTPtr parsed_query); void sendData(Block & sample, const ColumnsDescription & columns_description, ASTPtr parsed_query); void sendDataFrom(ReadBuffer & buf, Block & sample, const ColumnsDescription & columns_description, ASTPtr parsed_query); - void receiveResult(ASTPtr parsed_query); void receiveLogs(ASTPtr parsed_query); bool receiveEndOfQuery(); - bool receiveAndProcessPacket(ASTPtr parsed_query, bool cancelled); bool receiveSampleBlock(Block & out, ColumnsDescription & columns_description, ASTPtr parsed_query); - void initBlockOutputStream(const Block & block, ASTPtr parsed_query); - void initLogsOutputStream(); - - void onData(Block & block, ASTPtr parsed_query); - void onLogData(Block & block); - void onTotals(Block & block, ASTPtr parsed_query); - void onExtremes(Block & block, ASTPtr parsed_query); - void writeFinalProgress(); - void onReceiveExceptionFromServer(std::unique_ptr && e); - void onProfileInfo(const BlockStreamProfileInfo & profile_info); std::vector loadWarningMessages(); void reconnectIfNeeded() diff --git a/programs/local/LocalServer.cpp b/programs/local/LocalServer.cpp index fabe4c5af0a..a0d3b70ee39 100644 --- a/programs/local/LocalServer.cpp +++ b/programs/local/LocalServer.cpp @@ -257,10 +257,10 @@ void LocalServer::checkInterruptListener() } -void LocalServer::executeSingleQuery(const String & query_to_execute, ASTPtr /* parsed_query */) +void LocalServer::executeSingleQuery(const String & query_to_execute, ASTPtr parsed_query) { - ReadBufferFromString read_buf(query_to_execute); - WriteBufferFromFileDescriptor write_buf(STDOUT_FILENO); + // ReadBufferFromString read_buf(query_to_execute); + // WriteBufferFromFileDescriptor write_buf(STDOUT_FILENO); cancelled = false; @@ -291,8 +291,8 @@ void LocalServer::executeSingleQuery(const String & query_to_execute, ASTPtr /* }; } - if (is_interactive) - interrupt_listener.emplace(); + // if (is_interactive) + // interrupt_listener.emplace(); SCOPE_EXIT_SAFE({ if (interrupt_listener) @@ -310,7 +310,7 @@ void LocalServer::executeSingleQuery(const String & query_to_execute, ASTPtr /* try { - executeQuery(read_buf, write_buf, /* allow_into_outfile = */ true, query_context, {}, {}, flush_buffer_func); + processOrdinaryQuery(query_to_execute, parsed_query); } catch (const Exception & e) { @@ -473,6 +473,8 @@ try progress_indication.setFileProgressCallback(query_context); } + connect(); + if (is_interactive) { std::cout << std::endl; @@ -563,6 +565,12 @@ void LocalServer::processConfig() if (config().has("macros")) global_context->setMacros(std::make_unique(config(), "macros", log)); + is_default_format = !config().has("vertical") && !config().has("format"); + if (config().has("vertical")) + format = config().getString("format", "Vertical"); + else + format = config().getString("format", is_interactive ? "PrettyCompact" : "TabSeparated"); + /// Skip networking /// Sets external authenticators config (LDAP, Kerberos). @@ -631,6 +639,9 @@ void LocalServer::processConfig() std::map prompt_substitutions{{"display_name", server_display_name}}; for (const auto & [key, value] : prompt_substitutions) boost::replace_all(prompt_by_server_display_name, "{" + key + "}", value); + + ClientInfo & client_info = global_context->getClientInfo(); + client_info.setInitialQuery(); } diff --git a/programs/local/LocalServer.h b/programs/local/LocalServer.h index e8880971ee2..335cbdb2a3e 100644 --- a/programs/local/LocalServer.h +++ b/programs/local/LocalServer.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -34,6 +35,12 @@ public: } protected: + void connect() override + { + connection_parameters = ConnectionParameters(config()); + connection = std::make_unique(global_context); + } + void processSingleQuery(const String & full_query) override; bool processMultiQuery(const String & all_queries_text) override; diff --git a/src/Client/ClientBase.cpp b/src/Client/ClientBase.cpp index c5538895d36..ebed3d18a85 100644 --- a/src/Client/ClientBase.cpp +++ b/src/Client/ClientBase.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include @@ -36,10 +37,13 @@ #include #include +#include #include -#include #include +#include +#include + namespace fs = std::filesystem; @@ -51,6 +55,9 @@ static const NameSet exit_strings{"exit", "quit", "logout", "учше", "йгш namespace ErrorCodes { extern const int BAD_ARGUMENTS; + extern const int DEADLOCK_AVOIDED; + extern const int CLIENT_OUTPUT_FORMAT_SPECIFIED; + extern const int UNKNOWN_PACKET_FROM_SERVER; } } @@ -144,6 +151,371 @@ static void adjustQueryEnd(const char *& this_query_end, const char * all_querie } +/// Convert external tables to ExternalTableData and send them using the connection. +void ClientBase::sendExternalTables(ASTPtr parsed_query) +{ + const auto * select = parsed_query->as(); + if (!select && !external_tables.empty()) + throw Exception("External tables could be sent only with select query", ErrorCodes::BAD_ARGUMENTS); + + std::vector data; + for (auto & table : external_tables) + data.emplace_back(table.getData(global_context)); + + connection->sendExternalTablesData(data); +} + + +void ClientBase::onData(Block & block, ASTPtr parsed_query) +{ + if (!block) + return; + + processed_rows += block.rows(); + + /// Even if all blocks are empty, we still need to initialize the output stream to write empty resultset. + initBlockOutputStream(block, parsed_query); + + /// The header block containing zero rows was used to initialize + /// block_out_stream, do not output it. + /// Also do not output too much data if we're fuzzing. + if (block.rows() == 0 || (query_fuzzer_runs != 0 && processed_rows >= 100)) + return; + + if (need_render_progress) + progress_indication.clearProgressOutput(); + + block_out_stream->write(block); + written_first_block = true; + + /// Received data block is immediately displayed to the user. + block_out_stream->flush(); + + /// Restore progress bar after data block. + if (need_render_progress) + progress_indication.writeProgress(); +} + + +void ClientBase::onLogData(Block & block) +{ + initLogsOutputStream(); + progress_indication.clearProgressOutput(); + logs_out_stream->write(block); + logs_out_stream->flush(); +} + + +void ClientBase::onTotals(Block & block, ASTPtr parsed_query) +{ + initBlockOutputStream(block, parsed_query); + block_out_stream->setTotals(block); +} + + +void ClientBase::onExtremes(Block & block, ASTPtr parsed_query) +{ + initBlockOutputStream(block, parsed_query); + block_out_stream->setExtremes(block); +} + + +void ClientBase::onReceiveExceptionFromServer(std::unique_ptr && e) +{ + have_error = true; + server_exception = std::move(e); + resetOutput(); +} + + +void ClientBase::onProfileInfo(const BlockStreamProfileInfo & profile_info) +{ + if (profile_info.hasAppliedLimit() && block_out_stream) + block_out_stream->setRowsBeforeLimit(profile_info.getRowsBeforeLimit()); +} + + +void ClientBase::initBlockOutputStream(const Block & block, ASTPtr parsed_query) +{ + if (!block_out_stream) + { + /// Ignore all results when fuzzing as they can be huge. + if (query_fuzzer_runs) + { + block_out_stream = std::make_shared(block); + return; + } + + WriteBuffer * out_buf = nullptr; + String pager = config().getString("pager", ""); + if (!pager.empty()) + { + signal(SIGPIPE, SIG_IGN); + pager_cmd = ShellCommand::execute(pager, true); + out_buf = &pager_cmd->in; + } + else + { + out_buf = &std_out; + } + + String current_format = format; + + /// The query can specify output format or output file. + /// FIXME: try to prettify this cast using `as<>()` + if (const auto * query_with_output = dynamic_cast(parsed_query.get())) + { + if (query_with_output->out_file) + { + const auto & out_file_node = query_with_output->out_file->as(); + const auto & out_file = out_file_node.value.safeGet(); + + out_file_buf = wrapWriteBufferWithCompressionMethod( + std::make_unique(out_file, DBMS_DEFAULT_BUFFER_SIZE, O_WRONLY | O_EXCL | O_CREAT), + chooseCompressionMethod(out_file, ""), + /* compression level = */ 3 + ); + + // We are writing to file, so default format is the same as in non-interactive mode. + if (is_interactive && is_default_format) + current_format = "TabSeparated"; + } + if (query_with_output->format != nullptr) + { + if (has_vertical_output_suffix) + throw Exception("Output format already specified", ErrorCodes::CLIENT_OUTPUT_FORMAT_SPECIFIED); + const auto & id = query_with_output->format->as(); + current_format = id.name(); + } + } + + if (has_vertical_output_suffix) + current_format = "Vertical"; + + /// It is not clear how to write progress with parallel formatting. It may increase code complexity significantly. + if (!need_render_progress) + block_out_stream = global_context->getOutputStreamParallelIfPossible(current_format, out_file_buf ? *out_file_buf : *out_buf, block); + else + block_out_stream = global_context->getOutputStream(current_format, out_file_buf ? *out_file_buf : *out_buf, block); + + block_out_stream->writePrefix(); + } +} + + +void ClientBase::initLogsOutputStream() +{ + if (!logs_out_stream) + { + WriteBuffer * wb = out_logs_buf.get(); + + if (!out_logs_buf) + { + if (server_logs_file.empty()) + { + /// Use stderr by default + out_logs_buf = std::make_unique(STDERR_FILENO); + wb = out_logs_buf.get(); + } + else if (server_logs_file == "-") + { + /// Use stdout if --server_logs_file=- specified + wb = &std_out; + } + else + { + out_logs_buf + = std::make_unique(server_logs_file, DBMS_DEFAULT_BUFFER_SIZE, O_WRONLY | O_APPEND | O_CREAT); + wb = out_logs_buf.get(); + } + } + + logs_out_stream = std::make_shared(*wb, stdout_is_a_tty); + logs_out_stream->writePrefix(); + } +} + + +void ClientBase::processOrdinaryQuery(const String & query_to_execute, ASTPtr parsed_query) +{ + /// Rewrite query only when we have query parameters. + /// Note that if query is rewritten, comments in query are lost. + /// But the user often wants to see comments in server logs, query log, processlist, etc. + auto query = query_to_execute; + if (!query_parameters.empty()) + { + /// Replace ASTQueryParameter with ASTLiteral for prepared statements. + ReplaceQueryParameterVisitor visitor(query_parameters); + visitor.visit(parsed_query); + + /// Get new query after substitutions. Note that it cannot be done for INSERT query with embedded data. + query = serializeAST(*parsed_query); + } + + int retries_left = 10; + for (;;) + { + assert(retries_left > 0); + + try + { + connection->sendQuery( + connection_parameters.timeouts, + query, + global_context->getCurrentQueryId(), + query_processing_stage, + &global_context->getSettingsRef(), + &global_context->getClientInfo(), + true); + + sendExternalTables(parsed_query); + receiveResult(parsed_query); + + break; + } + catch (const Exception & e) + { + /// Retry when the server said "Client should retry" and no rows + /// has been received yet. + if (processed_rows == 0 && e.code() == ErrorCodes::DEADLOCK_AVOIDED && --retries_left) + { + std::cerr << "Got a transient error from the server, will" + << " retry (" << retries_left << " retries left)"; + } + else + { + throw; + } + } + } +} + + +/// Receives and processes packets coming from server. +/// Also checks if query execution should be cancelled. +void ClientBase::receiveResult(ASTPtr parsed_query) +{ + InterruptListener interrupt_listener; + bool cancelled = false; + + // TODO: get the poll_interval from commandline. + const auto receive_timeout = connection_parameters.timeouts.receive_timeout; + constexpr size_t default_poll_interval = 1000000; /// in microseconds + constexpr size_t min_poll_interval = 5000; /// in microseconds + const size_t poll_interval + = std::max(min_poll_interval, std::min(receive_timeout.totalMicroseconds(), default_poll_interval)); + + while (true) + { + Stopwatch receive_watch(CLOCK_MONOTONIC_COARSE); + + while (true) + { + /// Has the Ctrl+C been pressed and thus the query should be cancelled? + /// If this is the case, inform the server about it and receive the remaining packets + /// to avoid losing sync. + if (!cancelled) + { + auto cancel_query = [&] { + connection->sendCancel(); + cancelled = true; + if (is_interactive) + { + progress_indication.clearProgressOutput(); + std::cout << "Cancelling query." << std::endl; + } + + /// Pressing Ctrl+C twice results in shut down. + interrupt_listener.unblock(); + }; + + if (interrupt_listener.check()) + { + cancel_query(); + } + else + { + double elapsed = receive_watch.elapsedSeconds(); + if (elapsed > receive_timeout.totalSeconds()) + { + std::cout << "Timeout exceeded while receiving data from server." + << " Waited for " << static_cast(elapsed) << " seconds," + << " timeout is " << receive_timeout.totalSeconds() << " seconds." << std::endl; + + cancel_query(); + } + } + } + + /// Poll for changes after a cancellation check, otherwise it never reached + /// because of progress updates from server. + if (connection->poll(poll_interval)) + break; + } + + if (!receiveAndProcessPacket(parsed_query, cancelled)) + break; + } + + if (cancelled && is_interactive) + std::cout << "Query was cancelled." << std::endl; +} + + +/// Receive a part of the result, or progress info or an exception and process it. +/// Returns true if one should continue receiving packets. +/// Output of result is suppressed if query was cancelled. +bool ClientBase::receiveAndProcessPacket(ASTPtr parsed_query, bool cancelled) +{ + Packet packet = connection->receivePacket(); + + switch (packet.type) + { + case Protocol::Server::PartUUIDs: + return true; + + case Protocol::Server::Data: + if (!cancelled) + onData(packet.block, parsed_query); + return true; + + case Protocol::Server::Progress: + onProgress(packet.progress); + return true; + + case Protocol::Server::ProfileInfo: + onProfileInfo(packet.profile_info); + return true; + + case Protocol::Server::Totals: + if (!cancelled) + onTotals(packet.block, parsed_query); + return true; + + case Protocol::Server::Extremes: + if (!cancelled) + onExtremes(packet.block, parsed_query); + return true; + + case Protocol::Server::Exception: + onReceiveExceptionFromServer(std::move(packet.exception)); + return false; + + case Protocol::Server::Log: + onLogData(packet.block); + return true; + + case Protocol::Server::EndOfStream: + onEndOfStream(); + return false; + + default: + throw Exception( + ErrorCodes::UNKNOWN_PACKET_FROM_SERVER, "Unknown packet {} from server {}", packet.type, connection->getDescription()); + } +} + + void ClientBase::onProgress(const Progress & value) { if (!progress_indication.updateProgress(value)) @@ -221,10 +593,10 @@ void ClientBase::processSingleQueryImpl(const String & full_query, const String std_out.next(); } + global_context->setCurrentQueryId(""); if (is_interactive) { // Generate a new query_id - global_context->setCurrentQueryId(""); for (const auto & query_id_format : query_id_formats) { writeString(query_id_format.first, std_out); diff --git a/src/Client/ClientBase.h b/src/Client/ClientBase.h index 02e926632c6..ee6a757a9c2 100644 --- a/src/Client/ClientBase.h +++ b/src/Client/ClientBase.h @@ -7,6 +7,7 @@ #include #include #include +#include namespace po = boost::program_options; @@ -34,6 +35,14 @@ public: int main(const std::vector & /*args*/) override; protected: + void processOrdinaryQuery(const String & query_to_execute, ASTPtr parsed_query); + void receiveResult(ASTPtr parsed_query); + bool receiveAndProcessPacket(ASTPtr parsed_query, bool cancelled); + void initBlockOutputStream(const Block & block, ASTPtr parsed_query); + void initLogsOutputStream(); + void sendExternalTables(ASTPtr parsed_query); + virtual void connect() = 0; + /* * Run interactive or non-interactive mode. Depends on: * - processSingleQuery @@ -97,6 +106,12 @@ protected: virtual void loadSuggestionData(Suggest &) = 0; + void onData(Block & block, ASTPtr parsed_query); + void onLogData(Block & block); + void onTotals(Block & block, ASTPtr parsed_query); + void onExtremes(Block & block, ASTPtr parsed_query); + void onReceiveExceptionFromServer(std::unique_ptr && e); + void onProfileInfo(const BlockStreamProfileInfo & profile_info); void resetOutput(); @@ -200,6 +215,27 @@ protected: /// We will format query_id in interactive mode in various ways, the default is just to print Query id: ... std::vector> query_id_formats; + + /// Dictionary with query parameters for prepared statements. + NameToNameMap query_parameters; + + std::unique_ptr connection; + ConnectionParameters connection_parameters; + + String format; /// Query results output format. + bool is_default_format = true; /// false, if format is set in the config or command line. + + /// The last exception that was received from the server. Is used for the + /// return code in batch mode. + std::unique_ptr server_exception; + /// Likewise, the last exception that occurred on the client. + std::unique_ptr client_exception; + + QueryProcessingStage::Enum query_processing_stage; + + /// External tables info. + std::list external_tables; + }; } diff --git a/src/Client/Connection.cpp b/src/Client/Connection.cpp index 366e61bc8e2..d53247ef1ea 100644 --- a/src/Client/Connection.cpp +++ b/src/Client/Connection.cpp @@ -499,7 +499,7 @@ void Connection::sendQuery( /// Send empty block which means end of data. if (!with_pending_data) { - sendData(Block()); + sendData(Block(), /* name */"", /* scalar */false); out->next(); } } @@ -654,7 +654,7 @@ protected: num_rows += chunk.getNumRows(); auto block = getPort().getHeader().cloneWithColumns(chunk.detachColumns()); - connection.sendData(block, table_data.table_name); + connection.sendData(block, table_data.table_name, /* scalar */false); } private: @@ -670,7 +670,7 @@ void Connection::sendExternalTablesData(ExternalTablesData & data) if (data.empty()) { /// Send empty block, which means end of data transfer. - sendData(Block()); + sendData(Block(), "", false); return; } @@ -702,17 +702,16 @@ void Connection::sendExternalTablesData(ExternalTablesData & data) }); executor = pipeline.execute(); executor->execute(/*num_threads = */ 1); - - auto read_rows = sink->getNumReadRows(); +auto read_rows = sink->getNumReadRows(); rows += read_rows; /// If table is empty, send empty block with name. if (read_rows == 0) - sendData(sink->getPort().getHeader(), elem->table_name); + sendData(sink->getPort().getHeader(), elem->table_name, /* scalar */false); } /// Send empty block, which means end of data transfer. - sendData(Block()); + sendData(Block(), /* name */"", /* scalar */false); out_bytes = out->count() - out_bytes; maybe_compressed_out_bytes = maybe_compressed_out->count() - maybe_compressed_out_bytes; diff --git a/src/Client/Connection.h b/src/Client/Connection.h index 2ea5a236a13..9448ed36d96 100644 --- a/src/Client/Connection.h +++ b/src/Client/Connection.h @@ -4,20 +4,13 @@ #include -#include #if !defined(ARCADIA_BUILD) # include #endif -#include +#include #include -#include -#include -#include -#include -#include -#include #include #include @@ -31,46 +24,14 @@ namespace DB { -class ClientInfo; -class Pipe; struct Settings; -/// Struct which represents data we are going to send for external table. -struct ExternalTableData -{ - /// Pipe of data form table; - std::unique_ptr pipe; - std::string table_name; - std::function()> creating_pipe_callback; - /// Flag if need to stop reading. - std::atomic_bool is_cancelled = false; -}; - -using ExternalTableDataPtr = std::unique_ptr; -using ExternalTablesData = std::vector; - class Connection; using ConnectionPtr = std::shared_ptr; using Connections = std::vector; -/// Packet that could be received from server. -struct Packet -{ - UInt64 type; - - Block block; - std::unique_ptr exception; - std::vector multistring_message; - Progress progress; - BlockStreamProfileInfo profile_info; - std::vector part_uuids; - - Packet() : type(Protocol::Server::Hello) {} -}; - - /** Connection with database server, to use by client. * How to use - see Core/Protocol.h * (Implementation of server end - see Server/TCPHandler.h) @@ -78,7 +39,7 @@ struct Packet * As 'default_database' empty string could be passed * - in that case, server will use it's own default database. */ -class Connection : private boost::noncopyable +class Connection : public IServerConnection { friend class MultiplexedConnections; @@ -111,92 +72,77 @@ public: setDescription(); } - virtual ~Connection() = default; - /// Set throttler of network traffic. One throttler could be used for multiple connections to limit total traffic. - void setThrottler(const ThrottlerPtr & throttler_) + void setThrottler(const ThrottlerPtr & throttler_) override { throttler = throttler_; } - /// Change default database. Changes will take effect on next reconnect. - void setDefaultDatabase(const String & database); + void setDefaultDatabase(const String & database) override; void getServerVersion(const ConnectionTimeouts & timeouts, String & name, UInt64 & version_major, UInt64 & version_minor, UInt64 & version_patch, - UInt64 & revision); - UInt64 getServerRevision(const ConnectionTimeouts & timeouts); + UInt64 & revision) override; - const String & getServerTimezone(const ConnectionTimeouts & timeouts); - const String & getServerDisplayName(const ConnectionTimeouts & timeouts); + UInt64 getServerRevision(const ConnectionTimeouts & timeouts) override; + + const String & getServerTimezone(const ConnectionTimeouts & timeouts) override; + const String & getServerDisplayName(const ConnectionTimeouts & timeouts) override; /// For log and exception messages. - const String & getDescription() const; + const String & getDescription() const override; const String & getHost() const; UInt16 getPort() const; const String & getDefaultDatabase() const; Protocol::Compression getCompression() const { return compression; } - /// If last flag is true, you need to call sendExternalTablesData after. void sendQuery( - const ConnectionTimeouts & timeouts, - const String & query, - const String & query_id_ = "", - UInt64 stage = QueryProcessingStage::Complete, - const Settings * settings = nullptr, - const ClientInfo * client_info = nullptr, - bool with_pending_data = false); + const ConnectionTimeouts & timeouts, const String & query, + const String & query_id_, UInt64 stage, + const Settings * settings, const ClientInfo * client_info, + bool with_pending_database) override; - void sendCancel(); - /// Send block of data; if name is specified, server will write it to external (temporary) table of that name. - void sendData(const Block & block, const String & name = "", bool scalar = false); - /// Send all scalars. - void sendScalarsData(Scalars & data); - /// Send all contents of external (temporary) tables. - void sendExternalTablesData(ExternalTablesData & data); - /// Send parts' uuids to excluded them from query processing - void sendIgnoredPartUUIDs(const std::vector & uuids); + void sendCancel() override; + + void sendData(const Block & block, const String & name, bool scalar) override; + + void sendExternalTablesData(ExternalTablesData & data) override; + + bool poll(size_t timeout_microseconds) override; + + bool hasReadPendingData() const override; + + std::optional checkPacket(size_t timeout_microseconds) override; + + Packet receivePacket() override; + + void forceConnected(const ConnectionTimeouts & timeouts) override; + + bool isConnected() const override { return connected; } + + bool checkConnected() override { return connected && ping(); } + + void disconnect() override; - void sendReadTaskResponse(const String &); /// Send prepared block of data (serialized and, if need, compressed), that will be read from 'input'. /// You could pass size of serialized/compressed block. void sendPreparedData(ReadBuffer & input, size_t size, const String & name = ""); - /// Check, if has data to read. - bool poll(size_t timeout_microseconds = 0); - - /// Check, if has data in read buffer. - bool hasReadPendingData() const; - - /// Checks if there is input data in connection and reads packet ID. - std::optional checkPacket(size_t timeout_microseconds = 0); - - /// Receive packet from server. - Packet receivePacket(); - - /// If not connected yet, or if connection is broken - then connect. If cannot connect - throw an exception. - void forceConnected(const ConnectionTimeouts & timeouts); - - bool isConnected() const { return connected; } - - /// Check if connection is still active with ping request. - bool checkConnected() { return connected && ping(); } + void sendReadTaskResponse(const String &); + /// Send all scalars. + void sendScalarsData(Scalars & data); + /// Send parts' uuids to excluded them from query processing + void sendIgnoredPartUUIDs(const std::vector & uuids); TablesStatusResponse getTablesStatus(const ConnectionTimeouts & timeouts, const TablesStatusRequest & request); - /** Disconnect. - * This may be used, if connection is left in unsynchronised state - * (when someone continues to wait for something) after an exception. - */ - void disconnect(); - size_t outBytesCount() const { return out ? out->count() : 0; } size_t inBytesCount() const { return in ? in->count() : 0; } @@ -209,7 +155,6 @@ public: if (in) in->setAsyncCallback(std::move(async_callback)); } - private: String host; UInt16 port; diff --git a/src/Client/IServerConnection.h b/src/Client/IServerConnection.h new file mode 100644 index 00000000000..fc96bc16294 --- /dev/null +++ b/src/Client/IServerConnection.h @@ -0,0 +1,124 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + + +#include + + +namespace DB +{ + +class ClientInfo; + +/// Packet that could be received from server. +struct Packet +{ + UInt64 type; + + Block block; + std::unique_ptr exception; + std::vector multistring_message; + Progress progress; + BlockStreamProfileInfo profile_info; + std::vector part_uuids; + + Packet() : type(Protocol::Server::Hello) {} +}; + +/// Struct which represents data we are going to send for external table. +struct ExternalTableData +{ + /// Pipe of data form table; + std::unique_ptr pipe; + std::string table_name; + std::function()> creating_pipe_callback; + /// Flag if need to stop reading. + std::atomic_bool is_cancelled = false; +}; + +using ExternalTableDataPtr = std::unique_ptr; +using ExternalTablesData = std::vector; + + +class IServerConnection : boost::noncopyable +{ +public: + virtual ~IServerConnection() = default; + + virtual void setDefaultDatabase(const String & database) = 0; + + virtual void getServerVersion( + const ConnectionTimeouts & timeouts, String & name, + UInt64 & version_major, UInt64 & version_minor, + UInt64 & version_patch, UInt64 & revision) = 0; + + virtual UInt64 getServerRevision(const ConnectionTimeouts & timeouts) = 0; + + virtual const String & getServerTimezone(const ConnectionTimeouts & timeouts) = 0; + virtual const String & getServerDisplayName(const ConnectionTimeouts & timeouts) = 0; + + virtual const String & getDescription() const = 0; + + /// If last flag is true, you need to call sendExternalTablesData after. + virtual void sendQuery( + const ConnectionTimeouts & timeouts, + const String & query, + const String & query_id_ /* = "" */, + UInt64 stage/* = QueryProcessingStage::Complete */, + const Settings * settings /* = nullptr */, + const ClientInfo * client_info /* = nullptr */, + bool with_pending_data /* = false */) = 0; + + virtual void sendCancel() = 0; + + /// Send block of data; if name is specified, server will write it to external (temporary) table of that name. + virtual void sendData(const Block & block, const String & name/* = "" */, bool scalar/* = false */) = 0; + + /// Send all contents of external (temporary) tables. + virtual void sendExternalTablesData(ExternalTablesData & data) = 0; + + /// Check, if has data to read. + virtual bool poll(size_t timeout_microseconds /* = 0 */) = 0; + + /// Check, if has data in read buffer. + virtual bool hasReadPendingData() const = 0; + + /// Checks if there is input data in connection and reads packet ID. + virtual std::optional checkPacket(size_t timeout_microseconds /* = 0 */) = 0; + + /// Receive packet from server. + virtual Packet receivePacket() = 0; + + /// If not connected yet, or if connection is broken - then connect. If cannot connect - throw an exception. + virtual void forceConnected(const ConnectionTimeouts & timeouts) = 0; + + virtual bool isConnected() const = 0; + + /// Check if connection is still active with ping request. + virtual bool checkConnected() = 0; + + /** Disconnect. + * This may be used, if connection is left in unsynchronised state + * (when someone continues to wait for something) after an exception. + */ + virtual void disconnect() = 0; + + /// Set throttler of network traffic. One throttler could be used for multiple connections to limit total traffic. + virtual void setThrottler(const ThrottlerPtr & throttler_) = 0; +}; + +using ServerConnection = std::unique_ptr; + +} diff --git a/src/Client/LocalConnection.cpp b/src/Client/LocalConnection.cpp new file mode 100644 index 00000000000..4d734b4d523 --- /dev/null +++ b/src/Client/LocalConnection.cpp @@ -0,0 +1,215 @@ +#include "LocalConnection.h" +#include + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; + extern const int UNKNOWN_PACKET_FROM_SERVER; +} + +LocalConnection::LocalConnection(ContextPtr context_) + : WithContext(context_) +{ +} + +void LocalConnection::setDefaultDatabase(const String & database) +{ + default_database = database; +} + +void LocalConnection::getServerVersion( + const ConnectionTimeouts & /* timeouts */, String & name, + UInt64 & version_major, UInt64 & version_minor, + UInt64 & version_patch, UInt64 & revision) +{ + name = server_name; + version_major = server_version_major; + version_minor = server_version_minor; + version_patch = server_version_patch; + revision = server_revision; +} + +UInt64 LocalConnection::getServerRevision(const ConnectionTimeouts &) +{ + return server_revision; +} + +const String & LocalConnection::getDescription() const +{ + return description; +} + +const String & LocalConnection::getServerTimezone(const ConnectionTimeouts &) +{ + return server_timezone; +} + +const String & LocalConnection::getServerDisplayName(const ConnectionTimeouts &) +{ + return server_display_name; +} + +/* + * SendQuery: execute query and suspend the result, which will be received back via poll. +**/ +void LocalConnection::sendQuery( + const ConnectionTimeouts &, + const String & query_, + const String & query_id_, + UInt64, + const Settings *, + const ClientInfo *, + bool) +{ + query_context = Context::createCopy(getContext()); + query_context->makeQueryContext(); + query_context->setProgressCallback([this] (const Progress & value) { return this->updateProgress(value); }); + /// query_context->setCurrentDatabase(default_database); + + state.query_id = query_id_; + state.query = query_; + + state.io = executeQuery(state.query, query_context, false, state.stage, false); + if (state.io.out) + { + state.need_receive_data_for_insert = true; + /// processInsertQuery(); + } + else if (state.io.pipeline.initialized()) + { + state.executor = std::make_unique(state.io.pipeline); + } + else if (state.io.in) + { + state.async_in = std::make_unique(state.io.in); + state.async_in->readPrefix(); + } +} + +void LocalConnection::sendCancel() +{ + if (state.async_in) + { + state.async_in->cancel(false); + } + else if (state.executor) + { + state.executor->cancel(); + } +} + +Block LocalConnection::pullBlock() +{ + Block block; + if (state.async_in) + { + if (state.async_in->poll(query_context->getSettingsRef().interactive_delay / 1000)) + return state.async_in->read(); + } + else if (state.executor) + { + state.executor->pull(block, query_context->getSettingsRef().interactive_delay / 1000); + } + return block; +} + +void LocalConnection::finishQuery() +{ + if (state.async_in) + { + state.async_in->readSuffix(); + state.async_in.reset(); + } + else if (state.executor) + { + state.executor.reset(); + } + + // sendProgress(); + state.io.onFinish(); + query_context.reset(); +} + +bool LocalConnection::poll(size_t) +{ + if (after_send_progress.elapsed() / 1000 >= query_context->getSettingsRef().interactive_delay) + { + after_send_progress.restart(); + next_packet_type = Protocol::Server::Progress; + + return true; + } + + auto block = pullBlock(); + if (block) + { + next_packet_type = Protocol::Server::Data; + + if (state.io.null_format) + state.block.emplace(); + else + state.block.emplace(block); + } + else + { + state.is_finished = true; + next_packet_type = Protocol::Server::EndOfStream; + } + return true; +} + +Packet LocalConnection::receivePacket() +{ + Packet packet; + + packet.type = next_packet_type.value(); + switch (next_packet_type.value()) + { + case Protocol::Server::Data: + { + if (state.block) + { + packet.block = std::move(*state.block); + state.block.reset(); + } + + break; + } + case Protocol::Server::Progress: + { + packet.progress = std::move(state.progress); + state.progress.reset(); + break; + } + case Protocol::Server::EndOfStream: + { + finishQuery(); + break; + } + default: + throw Exception("Unknown packet " + toString(packet.type) + + " from server " + getDescription(), ErrorCodes::UNKNOWN_PACKET_FROM_SERVER); + } + return packet; +} + +bool LocalConnection::hasReadPendingData() const +{ + return !state.is_finished; +} + +std::optional LocalConnection::checkPacket(size_t) +{ + return next_packet_type; +} + +void LocalConnection::updateProgress(const Progress & value) +{ + state.progress.incrementPiecewiseAtomically(value); +} + +} diff --git a/src/Client/LocalConnection.h b/src/Client/LocalConnection.h new file mode 100644 index 00000000000..ed01a62ed0e --- /dev/null +++ b/src/Client/LocalConnection.h @@ -0,0 +1,160 @@ +#include "Connection.h" +#include +#include +#include +#include +#include + + +namespace DB +{ + +/// State of query processing. +struct LocalQueryState +{ + /// Identifier of the query. + String query_id; + + QueryProcessingStage::Enum stage = QueryProcessingStage::Complete; + + /// A queue with internal logs that will be passed to client. It must be + /// destroyed after input/output blocks, because they may contain other + /// threads that use this queue. + InternalTextLogsQueuePtr logs_queue; + BlockOutputStreamPtr logs_block_out; + + /// Query text. + String query; + /// Streams of blocks, that are processing the query. + BlockIO io; + /// Current stream to pull blocks from. + std::unique_ptr async_in; + std::unique_ptr executor; + + /// Last polled block. + std::optional block; + + /// Is request cancelled + bool is_cancelled = false; + /// Is query finished == !has_pending_data + bool is_finished = false; + /// empty or not + bool is_empty = true; + /// Data was sent. + bool sent_all_data = false; + /// Request requires data from the client (INSERT, but not INSERT SELECT). + bool need_receive_data_for_insert = false; + /// Temporary tables read + bool temporary_tables_read = false; + + /// A state got uuids to exclude from a query + bool part_uuids = false; + + /// Request requires data from client for function input() + bool need_receive_data_for_input = false; + /// temporary place for incoming data block for input() + Block block_for_input; + /// sample block from StorageInput + Block input_header; + + /// To output progress, the difference after the previous sending of progress. + Progress progress; + + /// Timeouts setter for current query + std::unique_ptr timeout_setter; +}; + + +class LocalConnection : public IServerConnection, WithContext +{ +public: + explicit LocalConnection(ContextPtr context_); + + void setDefaultDatabase(const String & database) override; + + void getServerVersion(const ConnectionTimeouts & timeouts, + String & name, + UInt64 & version_major, + UInt64 & version_minor, + UInt64 & version_patch, + UInt64 & revision) override; + + UInt64 getServerRevision(const ConnectionTimeouts & timeouts) override; + + const String & getServerTimezone(const ConnectionTimeouts & timeouts) override; + const String & getServerDisplayName(const ConnectionTimeouts & timeouts) override; + + const String & getDescription() const override; + + void sendQuery( + const ConnectionTimeouts & timeouts, + const String & query, + const String & query_id_ /* = "" */, + UInt64 stage/* = QueryProcessingStage::Complete */, + const Settings * settings /* = nullptr */, + const ClientInfo * client_info /* = nullptr */, + bool with_pending_data /* = false */) override; + + void sendCancel() override; + + void sendData(const Block &, const String &, bool) override {} + + void sendExternalTablesData(ExternalTablesData &) override {} + + bool poll(size_t timeout_microseconds) override; + + bool hasReadPendingData() const override; + + std::optional checkPacket(size_t timeout_microseconds) override; + + Packet receivePacket() override; + + void forceConnected(const ConnectionTimeouts &) override {} + + bool isConnected() const override { return true; } + + bool checkConnected() override { return true; } + + void disconnect() override {} + + void setThrottler(const ThrottlerPtr &) override {} + +private: + ContextMutablePtr query_context; + + String description; + + String server_name; + UInt64 server_version_major = 0; + UInt64 server_version_minor = 0; + UInt64 server_version_patch = 0; + UInt64 server_revision = 0; + String server_timezone; + String server_display_name; + String default_database; + + /// At the moment, only one ongoing query in the connection is supported at a time. + LocalQueryState state; + + /// Last "server" packet. + std::optional next_packet_type; + + /// Time after the last check to stop the request and send the progress. + Stopwatch after_check_cancelled; + Stopwatch after_send_progress; + + void initBlockInput(); + + void processOrdinaryQuery(); + + void processOrdinaryQueryWithProcessors(); + + void updateState(); + + Block pullBlock(); + + void finishQuery(); + + void updateProgress(const Progress & value); +}; +} diff --git a/src/Client/Suggest.cpp b/src/Client/Suggest.cpp index 806c4a2eb22..57e57a70a0b 100644 --- a/src/Client/Suggest.cpp +++ b/src/Client/Suggest.cpp @@ -197,7 +197,7 @@ void Suggest::loadImpl(Connection & connection, const ConnectionTimeouts & timeo void Suggest::fetch(Connection & connection, const ConnectionTimeouts & timeouts, const std::string & query) { - connection.sendQuery(timeouts, query, "" /* query_id */, QueryProcessingStage::Complete); + connection.sendQuery(timeouts, query, "" /* query_id */, QueryProcessingStage::Complete, nullptr, nullptr, false); while (true) { diff --git a/src/DataStreams/RemoteBlockOutputStream.cpp b/src/DataStreams/RemoteBlockOutputStream.cpp index 976c4671652..c0fcf9f1021 100644 --- a/src/DataStreams/RemoteBlockOutputStream.cpp +++ b/src/DataStreams/RemoteBlockOutputStream.cpp @@ -36,7 +36,7 @@ RemoteBlockOutputStream::RemoteBlockOutputStream(Connection & connection_, /** Send query and receive "header", that describes table structure. * Header is needed to know, what structure is required for blocks to be passed to 'write' method. */ - connection.sendQuery(timeouts, query, "", QueryProcessingStage::Complete, &settings_, &modified_client_info); + connection.sendQuery(timeouts, query, "", QueryProcessingStage::Complete, &settings_, &modified_client_info, false); while (true) { @@ -77,12 +77,12 @@ void RemoteBlockOutputStream::write(const Block & block) try { - connection.sendData(block); + connection.sendData(block, /* name */"", /* scalar */false); } catch (const NetException &) { /// Try to get more detailed exception from server - auto packet_type = connection.checkPacket(); + auto packet_type = connection.checkPacket(/* timeout_microseconds */0); if (packet_type && *packet_type == Protocol::Server::Exception) { Packet packet = connection.receivePacket(); @@ -104,7 +104,7 @@ void RemoteBlockOutputStream::writePrepared(ReadBuffer & input, size_t size) void RemoteBlockOutputStream::writeSuffix() { /// Empty block means end of data. - connection.sendData(Block()); + connection.sendData(Block(), /* name */"", /* scalar */false); /// Wait for EndOfStream or Exception packet, skip Log packets. while (true)