Extend the protocol with streaming and nonstreaming functions.

This commit is contained in:
Vitaly Baranov 2020-10-25 01:03:49 +03:00
parent b51e14253d
commit b0cb3eb306
3 changed files with 261 additions and 58 deletions

View File

@ -39,6 +39,7 @@ namespace ErrorCodes
{
extern const int INVALID_GRPC_QUERY_INFO;
extern const int INVALID_SESSION_TIMEOUT;
extern const int LOGICAL_ERROR;
extern const int NETWORK_ERROR;
extern const int NO_DATA_TO_INSERT;
extern const int UNKNOWN_DATABASE;
@ -112,40 +113,23 @@ namespace
using CompletionCallback = std::function<void(bool)>;
/// Requests a connection and provides low-level interface for reading and writing.
class Responder
class BaseResponder
{
public:
void start(
GRPCService & grpc_service,
virtual ~BaseResponder() = default;
virtual void start(GRPCService & grpc_service,
grpc::ServerCompletionQueue & new_call_queue,
grpc::ServerCompletionQueue & notification_queue,
const CompletionCallback & callback)
{
grpc_service.RequestExecuteQuery(&grpc_context, &reader_writer, &new_call_queue, &notification_queue, getCallbackPtr(callback));
}
const CompletionCallback & callback) = 0;
void read(GRPCQueryInfo & query_info_, const CompletionCallback & callback)
{
reader_writer.Read(&query_info_, getCallbackPtr(callback));
}
virtual void read(GRPCQueryInfo & query_info_, const CompletionCallback & callback) = 0;
virtual void write(const GRPCResult & result, const CompletionCallback & callback) = 0;
virtual void writeAndFinish(const GRPCResult & result, const grpc::Status & status, const CompletionCallback & callback) = 0;
void write(const GRPCResult & result, const CompletionCallback & callback)
{
reader_writer.Write(result, getCallbackPtr(callback));
}
Poco::Net::SocketAddress getClientAddress() const { String peer = grpc_context.peer(); return Poco::Net::SocketAddress{peer.substr(peer.find(':') + 1)}; }
void writeAndFinish(const GRPCResult & result, const grpc::Status & status, const CompletionCallback & callback)
{
reader_writer.WriteAndFinish(result, {}, status, getCallbackPtr(callback));
}
Poco::Net::SocketAddress getClientAddress() const
{
String peer = grpc_context.peer();
return Poco::Net::SocketAddress{peer.substr(peer.find(':') + 1)};
}
private:
protected:
CompletionCallback * getCallbackPtr(const CompletionCallback & callback)
{
/// It would be better to pass callbacks to gRPC calls.
@ -166,13 +150,198 @@ namespace
};
return &callback_in_map;
}
grpc::ServerContext grpc_context;
private:
grpc::ServerAsyncReaderWriter<GRPCResult, GRPCQueryInfo> reader_writer{&grpc_context};
std::unordered_map<size_t, CompletionCallback> callbacks;
size_t next_callback_id = 0;
std::mutex mutex;
};
enum CallType
{
CALL_SIMPLE, /// ExecuteQuery() call
CALL_WITH_STREAM_INPUT, /// ExecuteQueryWithStreamInput() call
CALL_WITH_STREAM_OUTPUT, /// ExecuteQueryWithStreamOutput() call
CALL_WITH_STREAM_IO, /// ExecuteQueryWithStreamIO() call
CALL_MAX,
};
const char * getCallName(CallType call_type)
{
switch (call_type)
{
case CALL_SIMPLE: return "ExecuteQuery()";
case CALL_WITH_STREAM_INPUT: return "ExecuteQueryWithStreamInput()";
case CALL_WITH_STREAM_OUTPUT: return "ExecuteQueryWithStreamOutput()";
case CALL_WITH_STREAM_IO: return "ExecuteQueryWithStreamIO()";
case CALL_MAX: break;
}
__builtin_unreachable();
}
bool isInputStreaming(CallType call_type)
{
return (call_type == CALL_WITH_STREAM_INPUT) || (call_type == CALL_WITH_STREAM_IO);
}
bool isOutputStreaming(CallType call_type)
{
return (call_type == CALL_WITH_STREAM_OUTPUT) || (call_type == CALL_WITH_STREAM_IO);
}
template <enum CallType call_type>
class Responder;
template<>
class Responder<CALL_SIMPLE> : public BaseResponder
{
public:
void start(GRPCService & grpc_service,
grpc::ServerCompletionQueue & new_call_queue,
grpc::ServerCompletionQueue & notification_queue,
const CompletionCallback & callback) override
{
grpc_service.RequestExecuteQuery(&grpc_context, &query_info.emplace(), &response_writer, &new_call_queue, &notification_queue, getCallbackPtr(callback));
}
void read(GRPCQueryInfo & query_info_, const CompletionCallback & callback) override
{
if (!query_info.has_value())
callback(false);
query_info_ = std::move(query_info).value();
query_info.reset();
callback(true);
}
void write(const GRPCResult &, const CompletionCallback &) override
{
throw Exception("Responder<CALL_SIMPLE>::write() should not be called", ErrorCodes::LOGICAL_ERROR);
}
void writeAndFinish(const GRPCResult & result, const grpc::Status & status, const CompletionCallback & callback) override
{
response_writer.Finish(result, status, getCallbackPtr(callback));
}
private:
grpc::ServerAsyncResponseWriter<GRPCResult> response_writer{&grpc_context};
std::optional<GRPCQueryInfo> query_info;
};
template<>
class Responder<CALL_WITH_STREAM_INPUT> : public BaseResponder
{
public:
void start(GRPCService & grpc_service,
grpc::ServerCompletionQueue & new_call_queue,
grpc::ServerCompletionQueue & notification_queue,
const CompletionCallback & callback) override
{
grpc_service.RequestExecuteQueryWithStreamInput(&grpc_context, &reader, &new_call_queue, &notification_queue, getCallbackPtr(callback));
}
void read(GRPCQueryInfo & query_info_, const CompletionCallback & callback) override
{
reader.Read(&query_info_, getCallbackPtr(callback));
}
void write(const GRPCResult &, const CompletionCallback &) override
{
throw Exception("Responder<CALL_WITH_STREAM_INPUT>::write() should not be called", ErrorCodes::LOGICAL_ERROR);
}
void writeAndFinish(const GRPCResult & result, const grpc::Status & status, const CompletionCallback & callback) override
{
reader.Finish(result, status, getCallbackPtr(callback));
}
private:
grpc::ServerAsyncReader<GRPCResult, GRPCQueryInfo> reader{&grpc_context};
};
template<>
class Responder<CALL_WITH_STREAM_OUTPUT> : public BaseResponder
{
public:
void start(GRPCService & grpc_service,
grpc::ServerCompletionQueue & new_call_queue,
grpc::ServerCompletionQueue & notification_queue,
const CompletionCallback & callback) override
{
grpc_service.RequestExecuteQueryWithStreamOutput(&grpc_context, &query_info.emplace(), &writer, &new_call_queue, &notification_queue, getCallbackPtr(callback));
}
void read(GRPCQueryInfo & query_info_, const CompletionCallback & callback) override
{
if (!query_info.has_value())
callback(false);
query_info_ = std::move(query_info).value();
query_info.reset();
callback(true);
}
void write(const GRPCResult & result, const CompletionCallback & callback) override
{
writer.Write(result, getCallbackPtr(callback));
}
void writeAndFinish(const GRPCResult & result, const grpc::Status & status, const CompletionCallback & callback) override
{
writer.WriteAndFinish(result, {}, status, getCallbackPtr(callback));
}
private:
grpc::ServerAsyncWriter<GRPCResult> writer{&grpc_context};
std::optional<GRPCQueryInfo> query_info;
};
template<>
class Responder<CALL_WITH_STREAM_IO> : public BaseResponder
{
public:
void start(GRPCService & grpc_service,
grpc::ServerCompletionQueue & new_call_queue,
grpc::ServerCompletionQueue & notification_queue,
const CompletionCallback & callback) override
{
grpc_service.RequestExecuteQueryWithStreamIO(&grpc_context, &reader_writer, &new_call_queue, &notification_queue, getCallbackPtr(callback));
}
void read(GRPCQueryInfo & query_info_, const CompletionCallback & callback) override
{
reader_writer.Read(&query_info_, getCallbackPtr(callback));
}
void write(const GRPCResult & result, const CompletionCallback & callback) override
{
reader_writer.Write(result, getCallbackPtr(callback));
}
void writeAndFinish(const GRPCResult & result, const grpc::Status & status, const CompletionCallback & callback) override
{
reader_writer.WriteAndFinish(result, {}, status, getCallbackPtr(callback));
}
private:
grpc::ServerAsyncReaderWriter<GRPCResult, GRPCQueryInfo> reader_writer{&grpc_context};
};
std::unique_ptr<BaseResponder> makeResponder(CallType call_type)
{
switch (call_type)
{
case CALL_SIMPLE: return std::make_unique<Responder<CALL_SIMPLE>>();
case CALL_WITH_STREAM_INPUT: return std::make_unique<Responder<CALL_WITH_STREAM_INPUT>>();
case CALL_WITH_STREAM_OUTPUT: return std::make_unique<Responder<CALL_WITH_STREAM_OUTPUT>>();
case CALL_WITH_STREAM_IO: return std::make_unique<Responder<CALL_WITH_STREAM_IO>>();
case CALL_MAX: break;
}
__builtin_unreachable();
}
/// Implementation of ReadBuffer, which just calls a callback.
class ReadBufferFromCallback : public ReadBuffer
@ -201,7 +370,7 @@ namespace
class Call
{
public:
Call(std::unique_ptr<Responder> responder_, IServer & iserver_, Poco::Logger * log_);
Call(CallType call_type_, std::unique_ptr<BaseResponder> responder_, IServer & iserver_, Poco::Logger * log_);
~Call();
void start(const std::function<void(void)> & on_finish_call_callback);
@ -234,7 +403,8 @@ namespace
void throwIfFailedToSendResult();
void sendException(const Exception & exception);
std::unique_ptr<Responder> responder;
const CallType call_type;
std::unique_ptr<BaseResponder> responder;
IServer & iserver;
Poco::Logger * log = nullptr;
@ -284,8 +454,8 @@ namespace
std::mutex dummy_mutex; /// Doesn't protect anything.
};
Call::Call(std::unique_ptr<Responder> responder_, IServer & iserver_, Poco::Logger * log_)
: responder(std::move(responder_)), iserver(iserver_), log(log_)
Call::Call(CallType call_type_, std::unique_ptr<BaseResponder> responder_, IServer & iserver_, Poco::Logger * log_)
: call_type(call_type_), responder(std::move(responder_)), iserver(iserver_), log(log_)
{
}
@ -338,7 +508,7 @@ namespace
void Call::receiveQuery()
{
LOG_INFO(log, "Handling call ExecuteQuery()");
LOG_INFO(log, "Handling call {}", getCallName(call_type));
readQueryInfo();
@ -520,6 +690,9 @@ namespace
if (!query_info.next_query_info())
break;
if (!isInputStreaming(call_type))
throw Exception("next_query_info is allowed to be set only for streaming input", ErrorCodes::INVALID_GRPC_QUERY_INFO);
readQueryInfo();
if (!query_info.query().empty() || !query_info.query_id().empty() || !query_info.settings().empty()
|| !query_info.database().empty() || !query_info.input_data_delimiter().empty() || !query_info.output_format().empty()
@ -668,7 +841,8 @@ namespace
LOG_INFO(
log,
"Finished call ExecuteQuery() in {} secs. (including reading by client: {}, writing by client: {})",
"Finished call {} in {} secs. (including reading by client: {}, writing by client: {})",
getCallName(call_type),
query_time.elapsedSeconds(),
static_cast<double>(waited_for_client_reading) / 1000000000ULL,
static_cast<double>(waited_for_client_writing) / 1000000000ULL);
@ -773,11 +947,12 @@ namespace
if (!values.read_rows && !values.read_bytes && !values.total_rows_to_read && !values.written_rows && !values.written_bytes)
return;
auto & grpc_progress = *result.mutable_progress();
grpc_progress.set_read_rows(values.read_rows);
grpc_progress.set_read_bytes(values.read_bytes);
grpc_progress.set_total_rows_to_read(values.total_rows_to_read);
grpc_progress.set_written_rows(values.written_rows);
grpc_progress.set_written_bytes(values.written_bytes);
/// Sum is used because we need to accumulate values for the case if streaming output is disabled.
grpc_progress.set_read_rows(grpc_progress.read_rows() + values.read_rows);
grpc_progress.set_read_bytes(grpc_progress.read_bytes() + values.read_bytes);
grpc_progress.set_total_rows_to_read(grpc_progress.total_rows_to_read() + values.total_rows_to_read);
grpc_progress.set_written_rows(grpc_progress.written_rows() + values.written_rows);
grpc_progress.set_written_bytes(grpc_progress.written_bytes() + values.written_bytes);
}
void Call::addTotalsToResult(const Block & totals)
@ -867,6 +1042,11 @@ namespace
if (responder_finished)
return;
/// If output is not streaming then only the final result can be sent.
bool send_final_message = finalize || result.has_exception();
if (!send_final_message && !isOutputStreaming(call_type))
return;
/// Wait for previous write to finish.
/// (gRPC doesn't allow to start sending another result while the previous is still being sending.)
if (sending_result)
@ -879,7 +1059,6 @@ namespace
throwIfFailedToSendResult();
/// Start sending the result.
bool send_final_message = finalize || result.has_exception();
LOG_DEBUG(log, "Sending {} result to the client: {}", (send_final_message ? "final" : "intermediate"), getResultDescription(result));
if (write_buffer)
@ -981,14 +1160,19 @@ private:
void startReceivingNewCalls()
{
std::lock_guard lock{mutex};
makeResponderForNewCall();
responders_for_new_calls.resize(CALL_MAX);
for (CallType call_type : ext::range(CALL_MAX))
makeResponderForNewCall(call_type);
}
void makeResponderForNewCall()
void makeResponderForNewCall(CallType call_type)
{
/// `mutex` is already locked.
responder_for_new_call = std::make_unique<Responder>();
responder_for_new_call->start(owner.grpc_service, *owner.queue, *owner.queue, [this](bool ok) { onNewCall(ok); });
responders_for_new_calls[call_type] = makeResponder(call_type);
responders_for_new_calls[call_type]->start(
owner.grpc_service, *owner.queue, *owner.queue,
[this, call_type](bool ok) { onNewCall(call_type, ok); });
}
void stopReceivingNewCalls()
@ -997,18 +1181,18 @@ private:
should_stop = true;
}
void onNewCall(bool responder_started_ok)
void onNewCall(CallType call_type, bool responder_started_ok)
{
std::lock_guard lock{mutex};
auto responder = std::move(responder_for_new_call);
auto responder = std::move(responders_for_new_calls[call_type]);
if (should_stop)
return;
makeResponderForNewCall();
makeResponderForNewCall(call_type);
if (responder_started_ok)
{
/// Connection established and the responder has been started.
/// So we pass this responder to a Call and make another responder for next connection.
auto new_call = std::make_unique<Call>(std::move(responder), owner.iserver, owner.log);
auto new_call = std::make_unique<Call>(call_type, std::move(responder), owner.iserver, owner.log);
auto * new_call_ptr = new_call.get();
current_calls[new_call_ptr] = std::move(new_call);
new_call_ptr->start([this, new_call_ptr]() { onFinishCall(new_call_ptr); });
@ -1035,9 +1219,15 @@ private:
finished_calls.clear(); /// Destroy finished calls.
/// If (should_stop == true) we continue processing until there is no active calls.
if (should_stop && current_calls.empty() && !responder_for_new_call)
if (should_stop && current_calls.empty())
{
bool all_responders_gone = std::all_of(
responders_for_new_calls.begin(), responders_for_new_calls.end(),
[](std::unique_ptr<BaseResponder> & responder) { return !responder; });
if (all_responders_gone)
break;
}
}
bool ok = false;
void * tag = nullptr;
@ -1054,7 +1244,7 @@ private:
GRPCServer & owner;
ThreadFromGlobalPool queue_thread;
std::unique_ptr<Responder> responder_for_new_call;
std::vector<std::unique_ptr<BaseResponder>> responders_for_new_calls;
std::map<Call *, std::unique_ptr<Call>> current_calls;
std::vector<std::unique_ptr<Call>> finished_calls;
bool should_stop = false;

View File

@ -75,5 +75,8 @@ message Result {
}
service ClickHouse {
rpc ExecuteQuery(stream QueryInfo) returns (stream Result) {}
rpc ExecuteQuery(QueryInfo) returns (Result) {}
rpc ExecuteQueryWithStreamInput(stream QueryInfo) returns (Result) {}
rpc ExecuteQueryWithStreamOutput(QueryInfo) returns (stream Result) {}
rpc ExecuteQueryWithStreamIO(stream QueryInfo) returns (stream Result) {}
}

View File

@ -39,20 +39,30 @@ def create_channel():
main_channel = channel
return channel
def query_common(query_text, settings={}, input_data=[], input_data_delimiter='', output_format='TabSeparated', query_id='123', session_id='', channel=None):
def query_common(query_text, settings={}, input_data=[], input_data_delimiter='', output_format='TabSeparated', query_id='123', session_id='', stream_output=False, channel=None):
if type(input_data) == str:
input_data = [input_data]
if not channel:
channel = main_channel
stub = clickhouse_grpc_pb2_grpc.ClickHouseStub(channel)
def send_query_info():
def query_info():
input_data_part = input_data.pop(0) if input_data else ''
yield clickhouse_grpc_pb2.QueryInfo(query=query_text, settings=settings, input_data=input_data_part, input_data_delimiter=input_data_delimiter,
return clickhouse_grpc_pb2.QueryInfo(query=query_text, settings=settings, input_data=input_data_part, input_data_delimiter=input_data_delimiter,
output_format=output_format, query_id=query_id, session_id=session_id, next_query_info=bool(input_data))
def send_query_info():
yield query_info()
while input_data:
input_data_part = input_data.pop(0)
yield clickhouse_grpc_pb2.QueryInfo(input_data=input_data_part, next_query_info=bool(input_data))
return list(stub.ExecuteQuery(send_query_info()))
stream_input = len(input_data) > 1
if stream_input and stream_output:
return list(stub.ExecuteQueryWithStreamIO(send_query_info()))
elif stream_input:
return [stub.ExecuteQueryWithStreamInput(send_query_info())]
elif stream_output:
return list(stub.ExecuteQueryWithStreamOutput(query_info()))
else:
return [stub.ExecuteQuery(query_info())]
def query_no_errors(*args, **kwargs):
results = query_common(*args, **kwargs)
@ -180,7 +190,7 @@ def test_logs():
assert "Peak memory usage" in logs
def test_progress():
results = query_no_errors("SELECT number, sleep(0.31) FROM numbers(8) SETTINGS max_block_size=2, interactive_delay=100000")
results = query_no_errors("SELECT number, sleep(0.31) FROM numbers(8) SETTINGS max_block_size=2, interactive_delay=100000", stream_output=True)
#print(results)
assert str(results) ==\
"""[progress {