From b0cb3eb306d5542e24015b8ffb989d9b89b7eba9 Mon Sep 17 00:00:00 2001 From: Vitaly Baranov Date: Sun, 25 Oct 2020 01:03:49 +0300 Subject: [PATCH] Extend the protocol with streaming and nonstreaming functions. --- src/Server/GRPCServer.cpp | 292 +++++++++++++++---- src/Server/grpc_protos/clickhouse_grpc.proto | 5 +- tests/integration/test_grpc_protocol/test.py | 22 +- 3 files changed, 261 insertions(+), 58 deletions(-) diff --git a/src/Server/GRPCServer.cpp b/src/Server/GRPCServer.cpp index d985a0d4df7..10ae6e873a4 100644 --- a/src/Server/GRPCServer.cpp +++ b/src/Server/GRPCServer.cpp @@ -39,6 +39,7 @@ namespace ErrorCodes { extern const int INVALID_GRPC_QUERY_INFO; extern const int INVALID_SESSION_TIMEOUT; + extern const int LOGICAL_ERROR; extern const int NETWORK_ERROR; extern const int NO_DATA_TO_INSERT; extern const int UNKNOWN_DATABASE; @@ -112,40 +113,23 @@ namespace using CompletionCallback = std::function; /// Requests a connection and provides low-level interface for reading and writing. - class Responder + class BaseResponder { public: - void start( - GRPCService & grpc_service, - grpc::ServerCompletionQueue & new_call_queue, - grpc::ServerCompletionQueue & notification_queue, - const CompletionCallback & callback) - { - grpc_service.RequestExecuteQuery(&grpc_context, &reader_writer, &new_call_queue, ¬ification_queue, getCallbackPtr(callback)); - } + virtual ~BaseResponder() = default; - void read(GRPCQueryInfo & query_info_, const CompletionCallback & callback) - { - reader_writer.Read(&query_info_, getCallbackPtr(callback)); - } + virtual void start(GRPCService & grpc_service, + grpc::ServerCompletionQueue & new_call_queue, + grpc::ServerCompletionQueue & notification_queue, + const CompletionCallback & callback) = 0; - void write(const GRPCResult & result, const CompletionCallback & callback) - { - reader_writer.Write(result, getCallbackPtr(callback)); - } + virtual void read(GRPCQueryInfo & query_info_, const CompletionCallback & callback) = 0; + virtual void write(const GRPCResult & result, const CompletionCallback & callback) = 0; + virtual void writeAndFinish(const GRPCResult & result, const grpc::Status & status, const CompletionCallback & callback) = 0; - void writeAndFinish(const GRPCResult & result, const grpc::Status & status, const CompletionCallback & callback) - { - reader_writer.WriteAndFinish(result, {}, status, getCallbackPtr(callback)); - } + Poco::Net::SocketAddress getClientAddress() const { String peer = grpc_context.peer(); return Poco::Net::SocketAddress{peer.substr(peer.find(':') + 1)}; } - Poco::Net::SocketAddress getClientAddress() const - { - String peer = grpc_context.peer(); - return Poco::Net::SocketAddress{peer.substr(peer.find(':') + 1)}; - } - - private: + protected: CompletionCallback * getCallbackPtr(const CompletionCallback & callback) { /// It would be better to pass callbacks to gRPC calls. @@ -166,13 +150,198 @@ namespace }; return &callback_in_map; } + grpc::ServerContext grpc_context; + + private: grpc::ServerAsyncReaderWriter reader_writer{&grpc_context}; std::unordered_map callbacks; size_t next_callback_id = 0; std::mutex mutex; }; + enum CallType + { + CALL_SIMPLE, /// ExecuteQuery() call + CALL_WITH_STREAM_INPUT, /// ExecuteQueryWithStreamInput() call + CALL_WITH_STREAM_OUTPUT, /// ExecuteQueryWithStreamOutput() call + CALL_WITH_STREAM_IO, /// ExecuteQueryWithStreamIO() call + CALL_MAX, + }; + + const char * getCallName(CallType call_type) + { + switch (call_type) + { + case CALL_SIMPLE: return "ExecuteQuery()"; + case CALL_WITH_STREAM_INPUT: return "ExecuteQueryWithStreamInput()"; + case CALL_WITH_STREAM_OUTPUT: return "ExecuteQueryWithStreamOutput()"; + case CALL_WITH_STREAM_IO: return "ExecuteQueryWithStreamIO()"; + case CALL_MAX: break; + } + __builtin_unreachable(); + } + + bool isInputStreaming(CallType call_type) + { + return (call_type == CALL_WITH_STREAM_INPUT) || (call_type == CALL_WITH_STREAM_IO); + } + + bool isOutputStreaming(CallType call_type) + { + return (call_type == CALL_WITH_STREAM_OUTPUT) || (call_type == CALL_WITH_STREAM_IO); + } + + template + class Responder; + + template<> + class Responder : public BaseResponder + { + public: + void start(GRPCService & grpc_service, + grpc::ServerCompletionQueue & new_call_queue, + grpc::ServerCompletionQueue & notification_queue, + const CompletionCallback & callback) override + { + grpc_service.RequestExecuteQuery(&grpc_context, &query_info.emplace(), &response_writer, &new_call_queue, ¬ification_queue, getCallbackPtr(callback)); + } + + void read(GRPCQueryInfo & query_info_, const CompletionCallback & callback) override + { + if (!query_info.has_value()) + callback(false); + query_info_ = std::move(query_info).value(); + query_info.reset(); + callback(true); + } + + void write(const GRPCResult &, const CompletionCallback &) override + { + throw Exception("Responder::write() should not be called", ErrorCodes::LOGICAL_ERROR); + } + + void writeAndFinish(const GRPCResult & result, const grpc::Status & status, const CompletionCallback & callback) override + { + response_writer.Finish(result, status, getCallbackPtr(callback)); + } + + private: + grpc::ServerAsyncResponseWriter response_writer{&grpc_context}; + std::optional query_info; + }; + + template<> + class Responder : public BaseResponder + { + public: + void start(GRPCService & grpc_service, + grpc::ServerCompletionQueue & new_call_queue, + grpc::ServerCompletionQueue & notification_queue, + const CompletionCallback & callback) override + { + grpc_service.RequestExecuteQueryWithStreamInput(&grpc_context, &reader, &new_call_queue, ¬ification_queue, getCallbackPtr(callback)); + } + + void read(GRPCQueryInfo & query_info_, const CompletionCallback & callback) override + { + reader.Read(&query_info_, getCallbackPtr(callback)); + } + + void write(const GRPCResult &, const CompletionCallback &) override + { + throw Exception("Responder::write() should not be called", ErrorCodes::LOGICAL_ERROR); + } + + void writeAndFinish(const GRPCResult & result, const grpc::Status & status, const CompletionCallback & callback) override + { + reader.Finish(result, status, getCallbackPtr(callback)); + } + + private: + grpc::ServerAsyncReader reader{&grpc_context}; + }; + + template<> + class Responder : public BaseResponder + { + public: + void start(GRPCService & grpc_service, + grpc::ServerCompletionQueue & new_call_queue, + grpc::ServerCompletionQueue & notification_queue, + const CompletionCallback & callback) override + { + grpc_service.RequestExecuteQueryWithStreamOutput(&grpc_context, &query_info.emplace(), &writer, &new_call_queue, ¬ification_queue, getCallbackPtr(callback)); + } + + void read(GRPCQueryInfo & query_info_, const CompletionCallback & callback) override + { + if (!query_info.has_value()) + callback(false); + query_info_ = std::move(query_info).value(); + query_info.reset(); + callback(true); + } + + void write(const GRPCResult & result, const CompletionCallback & callback) override + { + writer.Write(result, getCallbackPtr(callback)); + } + + void writeAndFinish(const GRPCResult & result, const grpc::Status & status, const CompletionCallback & callback) override + { + writer.WriteAndFinish(result, {}, status, getCallbackPtr(callback)); + } + + private: + grpc::ServerAsyncWriter writer{&grpc_context}; + std::optional query_info; + }; + + template<> + class Responder : public BaseResponder + { + public: + void start(GRPCService & grpc_service, + grpc::ServerCompletionQueue & new_call_queue, + grpc::ServerCompletionQueue & notification_queue, + const CompletionCallback & callback) override + { + grpc_service.RequestExecuteQueryWithStreamIO(&grpc_context, &reader_writer, &new_call_queue, ¬ification_queue, getCallbackPtr(callback)); + } + + void read(GRPCQueryInfo & query_info_, const CompletionCallback & callback) override + { + reader_writer.Read(&query_info_, getCallbackPtr(callback)); + } + + void write(const GRPCResult & result, const CompletionCallback & callback) override + { + reader_writer.Write(result, getCallbackPtr(callback)); + } + + void writeAndFinish(const GRPCResult & result, const grpc::Status & status, const CompletionCallback & callback) override + { + reader_writer.WriteAndFinish(result, {}, status, getCallbackPtr(callback)); + } + + private: + grpc::ServerAsyncReaderWriter reader_writer{&grpc_context}; + }; + + std::unique_ptr makeResponder(CallType call_type) + { + switch (call_type) + { + case CALL_SIMPLE: return std::make_unique>(); + case CALL_WITH_STREAM_INPUT: return std::make_unique>(); + case CALL_WITH_STREAM_OUTPUT: return std::make_unique>(); + case CALL_WITH_STREAM_IO: return std::make_unique>(); + case CALL_MAX: break; + } + __builtin_unreachable(); + } + /// Implementation of ReadBuffer, which just calls a callback. class ReadBufferFromCallback : public ReadBuffer @@ -201,7 +370,7 @@ namespace class Call { public: - Call(std::unique_ptr responder_, IServer & iserver_, Poco::Logger * log_); + Call(CallType call_type_, std::unique_ptr responder_, IServer & iserver_, Poco::Logger * log_); ~Call(); void start(const std::function & on_finish_call_callback); @@ -234,7 +403,8 @@ namespace void throwIfFailedToSendResult(); void sendException(const Exception & exception); - std::unique_ptr responder; + const CallType call_type; + std::unique_ptr responder; IServer & iserver; Poco::Logger * log = nullptr; @@ -284,8 +454,8 @@ namespace std::mutex dummy_mutex; /// Doesn't protect anything. }; - Call::Call(std::unique_ptr responder_, IServer & iserver_, Poco::Logger * log_) - : responder(std::move(responder_)), iserver(iserver_), log(log_) + Call::Call(CallType call_type_, std::unique_ptr responder_, IServer & iserver_, Poco::Logger * log_) + : call_type(call_type_), responder(std::move(responder_)), iserver(iserver_), log(log_) { } @@ -338,7 +508,7 @@ namespace void Call::receiveQuery() { - LOG_INFO(log, "Handling call ExecuteQuery()"); + LOG_INFO(log, "Handling call {}", getCallName(call_type)); readQueryInfo(); @@ -520,6 +690,9 @@ namespace if (!query_info.next_query_info()) break; + if (!isInputStreaming(call_type)) + throw Exception("next_query_info is allowed to be set only for streaming input", ErrorCodes::INVALID_GRPC_QUERY_INFO); + readQueryInfo(); if (!query_info.query().empty() || !query_info.query_id().empty() || !query_info.settings().empty() || !query_info.database().empty() || !query_info.input_data_delimiter().empty() || !query_info.output_format().empty() @@ -668,7 +841,8 @@ namespace LOG_INFO( log, - "Finished call ExecuteQuery() in {} secs. (including reading by client: {}, writing by client: {})", + "Finished call {} in {} secs. (including reading by client: {}, writing by client: {})", + getCallName(call_type), query_time.elapsedSeconds(), static_cast(waited_for_client_reading) / 1000000000ULL, static_cast(waited_for_client_writing) / 1000000000ULL); @@ -773,11 +947,12 @@ namespace if (!values.read_rows && !values.read_bytes && !values.total_rows_to_read && !values.written_rows && !values.written_bytes) return; auto & grpc_progress = *result.mutable_progress(); - grpc_progress.set_read_rows(values.read_rows); - grpc_progress.set_read_bytes(values.read_bytes); - grpc_progress.set_total_rows_to_read(values.total_rows_to_read); - grpc_progress.set_written_rows(values.written_rows); - grpc_progress.set_written_bytes(values.written_bytes); + /// Sum is used because we need to accumulate values for the case if streaming output is disabled. + grpc_progress.set_read_rows(grpc_progress.read_rows() + values.read_rows); + grpc_progress.set_read_bytes(grpc_progress.read_bytes() + values.read_bytes); + grpc_progress.set_total_rows_to_read(grpc_progress.total_rows_to_read() + values.total_rows_to_read); + grpc_progress.set_written_rows(grpc_progress.written_rows() + values.written_rows); + grpc_progress.set_written_bytes(grpc_progress.written_bytes() + values.written_bytes); } void Call::addTotalsToResult(const Block & totals) @@ -867,6 +1042,11 @@ namespace if (responder_finished) return; + /// If output is not streaming then only the final result can be sent. + bool send_final_message = finalize || result.has_exception(); + if (!send_final_message && !isOutputStreaming(call_type)) + return; + /// Wait for previous write to finish. /// (gRPC doesn't allow to start sending another result while the previous is still being sending.) if (sending_result) @@ -879,7 +1059,6 @@ namespace throwIfFailedToSendResult(); /// Start sending the result. - bool send_final_message = finalize || result.has_exception(); LOG_DEBUG(log, "Sending {} result to the client: {}", (send_final_message ? "final" : "intermediate"), getResultDescription(result)); if (write_buffer) @@ -981,14 +1160,19 @@ private: void startReceivingNewCalls() { std::lock_guard lock{mutex}; - makeResponderForNewCall(); + responders_for_new_calls.resize(CALL_MAX); + for (CallType call_type : ext::range(CALL_MAX)) + makeResponderForNewCall(call_type); } - void makeResponderForNewCall() + void makeResponderForNewCall(CallType call_type) { /// `mutex` is already locked. - responder_for_new_call = std::make_unique(); - responder_for_new_call->start(owner.grpc_service, *owner.queue, *owner.queue, [this](bool ok) { onNewCall(ok); }); + responders_for_new_calls[call_type] = makeResponder(call_type); + + responders_for_new_calls[call_type]->start( + owner.grpc_service, *owner.queue, *owner.queue, + [this, call_type](bool ok) { onNewCall(call_type, ok); }); } void stopReceivingNewCalls() @@ -997,18 +1181,18 @@ private: should_stop = true; } - void onNewCall(bool responder_started_ok) + void onNewCall(CallType call_type, bool responder_started_ok) { std::lock_guard lock{mutex}; - auto responder = std::move(responder_for_new_call); + auto responder = std::move(responders_for_new_calls[call_type]); if (should_stop) return; - makeResponderForNewCall(); + makeResponderForNewCall(call_type); if (responder_started_ok) { /// Connection established and the responder has been started. /// So we pass this responder to a Call and make another responder for next connection. - auto new_call = std::make_unique(std::move(responder), owner.iserver, owner.log); + auto new_call = std::make_unique(call_type, std::move(responder), owner.iserver, owner.log); auto * new_call_ptr = new_call.get(); current_calls[new_call_ptr] = std::move(new_call); new_call_ptr->start([this, new_call_ptr]() { onFinishCall(new_call_ptr); }); @@ -1035,8 +1219,14 @@ private: finished_calls.clear(); /// Destroy finished calls. /// If (should_stop == true) we continue processing until there is no active calls. - if (should_stop && current_calls.empty() && !responder_for_new_call) - break; + if (should_stop && current_calls.empty()) + { + bool all_responders_gone = std::all_of( + responders_for_new_calls.begin(), responders_for_new_calls.end(), + [](std::unique_ptr & responder) { return !responder; }); + if (all_responders_gone) + break; + } } bool ok = false; @@ -1054,7 +1244,7 @@ private: GRPCServer & owner; ThreadFromGlobalPool queue_thread; - std::unique_ptr responder_for_new_call; + std::vector> responders_for_new_calls; std::map> current_calls; std::vector> finished_calls; bool should_stop = false; diff --git a/src/Server/grpc_protos/clickhouse_grpc.proto b/src/Server/grpc_protos/clickhouse_grpc.proto index 6d42c2be2de..665f3247dbb 100644 --- a/src/Server/grpc_protos/clickhouse_grpc.proto +++ b/src/Server/grpc_protos/clickhouse_grpc.proto @@ -75,5 +75,8 @@ message Result { } service ClickHouse { - rpc ExecuteQuery(stream QueryInfo) returns (stream Result) {} + rpc ExecuteQuery(QueryInfo) returns (Result) {} + rpc ExecuteQueryWithStreamInput(stream QueryInfo) returns (Result) {} + rpc ExecuteQueryWithStreamOutput(QueryInfo) returns (stream Result) {} + rpc ExecuteQueryWithStreamIO(stream QueryInfo) returns (stream Result) {} } diff --git a/tests/integration/test_grpc_protocol/test.py b/tests/integration/test_grpc_protocol/test.py index 11bac101f76..b1847ec5388 100644 --- a/tests/integration/test_grpc_protocol/test.py +++ b/tests/integration/test_grpc_protocol/test.py @@ -39,20 +39,30 @@ def create_channel(): main_channel = channel return channel -def query_common(query_text, settings={}, input_data=[], input_data_delimiter='', output_format='TabSeparated', query_id='123', session_id='', channel=None): +def query_common(query_text, settings={}, input_data=[], input_data_delimiter='', output_format='TabSeparated', query_id='123', session_id='', stream_output=False, channel=None): if type(input_data) == str: input_data = [input_data] if not channel: channel = main_channel stub = clickhouse_grpc_pb2_grpc.ClickHouseStub(channel) - def send_query_info(): + def query_info(): input_data_part = input_data.pop(0) if input_data else '' - yield clickhouse_grpc_pb2.QueryInfo(query=query_text, settings=settings, input_data=input_data_part, input_data_delimiter=input_data_delimiter, - output_format=output_format, query_id=query_id, session_id=session_id, next_query_info=bool(input_data)) + return clickhouse_grpc_pb2.QueryInfo(query=query_text, settings=settings, input_data=input_data_part, input_data_delimiter=input_data_delimiter, + output_format=output_format, query_id=query_id, session_id=session_id, next_query_info=bool(input_data)) + def send_query_info(): + yield query_info() while input_data: input_data_part = input_data.pop(0) yield clickhouse_grpc_pb2.QueryInfo(input_data=input_data_part, next_query_info=bool(input_data)) - return list(stub.ExecuteQuery(send_query_info())) + stream_input = len(input_data) > 1 + if stream_input and stream_output: + return list(stub.ExecuteQueryWithStreamIO(send_query_info())) + elif stream_input: + return [stub.ExecuteQueryWithStreamInput(send_query_info())] + elif stream_output: + return list(stub.ExecuteQueryWithStreamOutput(query_info())) + else: + return [stub.ExecuteQuery(query_info())] def query_no_errors(*args, **kwargs): results = query_common(*args, **kwargs) @@ -180,7 +190,7 @@ def test_logs(): assert "Peak memory usage" in logs def test_progress(): - results = query_no_errors("SELECT number, sleep(0.31) FROM numbers(8) SETTINGS max_block_size=2, interactive_delay=100000") + results = query_no_errors("SELECT number, sleep(0.31) FROM numbers(8) SETTINGS max_block_size=2, interactive_delay=100000", stream_output=True) #print(results) assert str(results) ==\ """[progress {