Refine the protocol. Code cleanup in tests.

This commit is contained in:
Vitaly Baranov 2020-10-05 23:33:34 +03:00
parent 3856ae1a5e
commit a327f24e3c
5 changed files with 196 additions and 149 deletions

View File

@ -20,10 +20,11 @@
#include <grpc++/server.h>
#include <grpc++/server_builder.h>
using GRPCConnection::QueryRequest;
using GRPCConnection::QueryResponse;
using GRPCConnection::GRPC;
using GRPCService = clickhouse::grpc::ClickHouse::AsyncService;
using GRPCQueryInfo = clickhouse::grpc::QueryInfo;
using GRPCResult = clickhouse::grpc::Result;
using GRPCException = clickhouse::grpc::Exception;
using GRPCProgress = clickhouse::grpc::Progress;
namespace DB
{
@ -39,7 +40,7 @@ namespace
class CommonCallData
{
public:
GRPC::AsyncService * grpc_service;
GRPCService * grpc_service;
grpc::ServerCompletionQueue * notification_cq;
grpc::ServerCompletionQueue * new_call_cq;
grpc::ServerContext grpc_context;
@ -49,7 +50,7 @@ namespace
std::unique_ptr<CommonCallData> next_client;
explicit CommonCallData(
GRPC::AsyncService * grpc_service_,
GRPCService * grpc_service_,
grpc::ServerCompletionQueue * notification_cq_,
grpc::ServerCompletionQueue * new_call_cq_,
IServer * iserver_,
@ -65,7 +66,7 @@ namespace
{
public:
CallDataQuery(
GRPC::AsyncService * grpc_service_,
GRPCService * grpc_service_,
grpc::ServerCompletionQueue * notification_cq_,
grpc::ServerCompletionQueue * new_call_cq_,
IServer * iserver_,
@ -75,7 +76,7 @@ namespace
details_status = SEND_TOTALS;
status = START_QUERY;
out = std::make_shared<WriteBufferFromGRPC>(&responder, static_cast<void *>(this), nullptr);
grpc_service->RequestQuery(&grpc_context, &responder, new_call_cq, notification_cq, this);
grpc_service->RequestExecuteQuery(&grpc_context, &responder, new_call_cq, notification_cq, this);
}
void parseQuery();
void parseData();
@ -116,9 +117,9 @@ namespace
}
private:
QueryRequest request;
QueryResponse response;
grpc::ServerAsyncReaderWriter<QueryResponse, QueryRequest> responder;
GRPCQueryInfo request;
GRPCResult response;
grpc::ServerAsyncReaderWriter<GRPCResult, GRPCQueryInfo> responder;
Stopwatch progress_watch;
Stopwatch query_watch;
@ -180,9 +181,9 @@ namespace
io.onException();
tryLogCurrentException(log);
std::string exception_message = getCurrentExceptionMessage(with_stacktrace, true);
//int exception_code = getCurrentExceptionCode(); //?
response.set_exception_occured(exception_message);
auto & grpc_exception = *response.mutable_exception();
grpc_exception.set_code(getCurrentExceptionCode());
grpc_exception.set_message(getCurrentExceptionMessage(with_stacktrace, true));
status = FINISH_QUERY;
responder.WriteAndFinish(response, grpc::WriteOptions(), grpc::Status(), static_cast<void *>(this));
}
@ -193,52 +194,51 @@ namespace
LOG_TRACE(log, "Process query");
Poco::Net::SocketAddress user_adress(parseGRPCPeer(grpc_context));
LOG_TRACE(log, "Request: {}", request.query_info().query());
LOG_TRACE(log, "Request: {}", request.query());
std::string user = request.user_info().user();
std::string password = request.user_info().password();
std::string quota_key = request.user_info().quota();
interactive_delay = request.interactive_delay();
std::string user = request.user_name();
std::string password = request.password();
std::string quota_key = request.quota();
format_output = "Values";
if (user.empty())
{
user = "default";
password = "";
}
if (interactive_delay == 0)
interactive_delay = INT_MAX;
context.setProgressCallback([this](const Progress & value) { return progress.incrementPiecewiseAtomically(value); });
query_context = context;
query_scope.emplace(*query_context);
query_context->setUser(user, password, user_adress);
query_context->setCurrentQueryId(request.query_info().query_id());
query_context->setCurrentQueryId(request.query_id());
if (!quota_key.empty())
query_context->setQuotaKey(quota_key);
if (!request.query_info().format().empty())
if (!request.output_format().empty())
{
format_output = request.query_info().format();
query_context->setDefaultFormat(request.query_info().format());
format_output = request.output_format();
query_context->setDefaultFormat(request.output_format());
}
if (!request.query_info().database().empty())
if (!request.database().empty())
{
if (!DatabaseCatalog::instance().isDatabaseExist(request.query_info().database()))
if (!DatabaseCatalog::instance().isDatabaseExist(request.database()))
{
Exception e("Database " + request.query_info().database() + " doesn't exist", ErrorCodes::UNKNOWN_DATABASE);
Exception e("Database " + request.database() + " doesn't exist", ErrorCodes::UNKNOWN_DATABASE);
}
query_context->setCurrentDatabase(request.query_info().database());
query_context->setCurrentDatabase(request.database());
}
SettingsChanges settings_changes;
for (const auto & [key, value] : request.query_info().settings())
for (const auto & [key, value] : request.settings())
{
settings_changes.push_back({key, value});
}
query_context->checkSettingsConstraints(settings_changes);
query_context->applySettingsChanges(settings_changes);
interactive_delay = query_context->getSettingsRef().interactive_delay;
ClientInfo & client_info = query_context->getClientInfo();
client_info.query_kind = ClientInfo::QueryKind::INITIAL_QUERY;
client_info.interface = ClientInfo::Interface::GRPC;
@ -251,8 +251,8 @@ namespace
void CallDataQuery::parseData()
{
LOG_TRACE(log, "ParseData");
const char * begin = request.query_info().query().data();
const char * end = begin + request.query_info().query().size();
const char * begin = request.query().data();
const char * end = begin + request.query().size();
const Settings & settings = query_context->getSettingsRef();
ParserQuery parser(end);
@ -269,7 +269,7 @@ namespace
io = ::DB::executeQuery(query, *query_context, false, QueryProcessingStage::Complete, true, true);
if (io.out)
{
if (!insert_query || !(insert_query->data || request.query_info().data_stream() || !request.insert_data().empty()))
if (!insert_query || !(insert_query->data || !request.input_data().empty() || request.next_query_info()))
{
Exception e("Logical error: query requires data to insert, but it is not INSERT query", ErrorCodes::NO_DATA_TO_INSERT);
}
@ -289,9 +289,9 @@ namespace
buffers.push_back(data_in_query.get());
}
if (!request.insert_data().empty())
if (!request.input_data().empty())
{
data_in_insert_data = std::make_shared<ReadBufferFromMemory>(request.insert_data().data(), request.insert_data().size());
data_in_insert_data = std::make_shared<ReadBufferFromMemory>(request.input_data().data(), request.input_data().size());
buffers.push_back(data_in_insert_data.get());
}
auto input_buffer_contacenated = std::make_unique<ConcatReadBuffer>(buffers);
@ -309,7 +309,7 @@ namespace
io.out->writePrefix();
while (auto block = res_stream->read())
io.out->write(block);
if (request.query_info().data_stream())
if (request.next_query_info())
{
status = READ_DATA;
responder.Read(&request, static_cast<void *>(this));
@ -323,23 +323,27 @@ namespace
void CallDataQuery::readData()
{
if (request.insert_data().empty())
if (!request.input_data().empty())
{
io.out->writeSuffix();
executeQuery();
}
else
{
const char * begin = request.insert_data().data();
const char * end = begin + request.insert_data().size();
const char * begin = request.input_data().data();
const char * end = begin + request.input_data().size();
ReadBufferFromMemory data_in(begin, end - begin);
auto res_stream = query_context->getInputFormat(
format_input, data_in, io.out->getHeader(), query_context->getSettings().max_insert_block_size);
while (auto block = res_stream->read())
io.out->write(block);
}
if (request.next_query_info())
{
responder.Read(&request, static_cast<void *>(this));
}
else
{
io.out->writeSuffix();
executeQuery();
}
}
void CallDataQuery::executeQuery()
@ -426,7 +430,7 @@ namespace
{
out->setResponse([](const String & buffer)
{
QueryResponse tmp_response;
GRPCResult tmp_response;
tmp_response.set_output(buffer);
return tmp_response;
});
@ -444,7 +448,7 @@ namespace
auto in = std::make_unique<ReadBufferFromString>(buffer);
ProgressValues progress_values;
progress_values.read(*in, DBMS_MIN_REVISION_WITH_CLIENT_WRITE_INFO);
GRPCConnection::Progress tmp_progress;
GRPCProgress tmp_progress;
tmp_progress.set_read_rows(progress_values.read_rows);
tmp_progress.set_read_bytes(progress_values.read_bytes);
tmp_progress.set_total_rows_to_read(progress_values.total_rows_to_read);
@ -455,8 +459,8 @@ namespace
out->setResponse([&grpc_progress](const String & buffer)
{
QueryResponse tmp_response;
auto tmp_progress = std::make_unique<GRPCConnection::Progress>(grpc_progress(buffer));
GRPCResult tmp_response;
auto tmp_progress = std::make_unique<GRPCProgress>(grpc_progress(buffer));
tmp_response.set_allocated_progress(tmp_progress.release());
return tmp_response;
});
@ -472,7 +476,7 @@ namespace
{
out->setResponse([](const String & buffer)
{
QueryResponse tmp_response;
GRPCResult tmp_response;
tmp_response.set_totals(buffer);
return tmp_response;
});
@ -491,7 +495,7 @@ namespace
{
out->setResponse([](const String & buffer)
{
QueryResponse tmp_response;
GRPCResult tmp_response;
tmp_response.set_extremes(buffer);
return tmp_response;
});

View File

@ -33,12 +33,12 @@ public:
void HandleRpcs();
private:
using GRPC = GRPCConnection::GRPC;
using GRPCService = clickhouse::grpc::ClickHouse::AsyncService;
IServer & iserver;
Poco::Logger * log;
std::unique_ptr<grpc::ServerCompletionQueue> notification_cq;
std::unique_ptr<grpc::ServerCompletionQueue> new_call_cq;
GRPC::AsyncService grpc_service;
GRPCService grpc_service;
std::unique_ptr<grpc::Server> grpc_server;
std::string address_to_listen;
};

View File

@ -11,13 +11,13 @@ namespace DB
class WriteBufferFromGRPC : public BufferWithOwnMemory<WriteBuffer>
{
public:
using QueryRequest = GRPCConnection::QueryRequest;
using QueryResponse = GRPCConnection::QueryResponse;
using GRPCQueryInfo = clickhouse::grpc::QueryInfo;
using GRPCResult = clickhouse::grpc::Result;
WriteBufferFromGRPC(
grpc::ServerAsyncReaderWriter<QueryResponse, QueryRequest> * responder_,
grpc::ServerAsyncReaderWriter<GRPCResult, GRPCQueryInfo> * responder_,
void * tag_,
std::function<QueryResponse(const String & buffer)> set_response_details_)
std::function<GRPCResult(const String & buffer)> set_response_details_)
: responder(responder_), tag(tag_), set_response_details(set_response_details_)
{
}
@ -26,7 +26,7 @@ public:
bool onProgress() { return progress; }
bool isFinished() { return finished; }
void setFinish(bool fl) { finished = fl; }
void setResponse(std::function<QueryResponse(const String & buffer)> function) { set_response_details = function; }
void setResponse(std::function<GRPCResult(const String & buffer)> function) { set_response_details = function; }
void finalize() override
{
progress = false;
@ -35,12 +35,12 @@ public:
}
protected:
grpc::ServerAsyncReaderWriter<QueryResponse, QueryRequest> * responder;
grpc::ServerAsyncReaderWriter<GRPCResult, GRPCQueryInfo> * responder;
void * tag;
bool progress = false;
bool finished = false;
std::function<QueryResponse(const String & buffer)> set_response_details;
std::function<GRPCResult(const String & buffer)> set_response_details;
void nextImpl() override

View File

@ -1,25 +1,18 @@
syntax = "proto3";
package GRPCConnection;
package clickhouse.grpc;
message User {
string user = 1;
string password = 2;
string quota = 3;
}
message QuerySettings {
message QueryInfo {
string query = 1;
string query_id = 2;
bool data_stream = 4;
string database = 5;
string format = 6;
map<string, string> settings = 7;
}
message QueryRequest {
User user_info = 1;
QuerySettings query_info = 2;
string insert_data = 3;
uint64 interactive_delay = 4;
map<string, string> settings = 3;
string database = 4;
string input_data = 5;
string output_format = 6;
string user_name = 7;
string password = 8;
string quota = 9;
bool next_query_info = 10;
}
message Progress {
@ -30,14 +23,19 @@ message Progress {
uint64 written_bytes = 5;
}
message QueryResponse {
string output = 1;
string exception_occured = 2;
Progress progress = 3;
string totals = 4;
string extremes = 5;
// Error information attached to a Result when query processing fails.
message Exception {
// Numeric error code; the server fills it from getCurrentExceptionCode().
int32 code = 1;
// Human-readable error text; the server fills it from getCurrentExceptionMessage().
string message = 2;
}
service GRPC {
rpc Query(stream QueryRequest) returns (stream QueryResponse) {}
// One message of the response stream returned by ExecuteQuery.
// The server fills the fields through separate response builders
// (set_output / set_totals / set_extremes / progress / exception).
message Result {
// Chunk of the formatted query output.
string output = 1;
// Formatted totals.
string totals = 2;
// Formatted extremes.
string extremes = 3;
// Progress counters (rows/bytes read and written).
Progress progress = 4;
// Set when the query failed; carries the error code and message.
Exception exception = 5;
}
// gRPC service exposed by the server.
service ClickHouse {
// Bidirectional-streaming call: the client sends a stream of QueryInfo messages
// and receives a stream of Result messages.
rpc ExecuteQuery(stream QueryInfo) returns (stream Result) {}
}

View File

@ -7,86 +7,131 @@ from helpers.cluster import ClickHouseCluster
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
# Use grpcio-tools to generate *pb2.py files from *.proto.
proto_dir = os.path.join(SCRIPT_DIR, './protos')
proto_gen_dir = os.path.join(SCRIPT_DIR, './_gen')
os.makedirs(proto_gen_dir, exist_ok=True)
subprocess.check_call(
'python3 -m grpc_tools.protoc -I{proto_dir} --python_out={proto_gen_dir} --grpc_python_out={proto_gen_dir} \
{proto_dir}/clickhouse_grpc.proto'.format(proto_dir=proto_dir, proto_gen_dir=proto_gen_dir), shell=True)
# Import everything from the generated *pb2.py files.
sys.path.append(proto_gen_dir)
# Use grpcio-tools to generate *pb2.py files from *.proto.
proto_dir = os.path.join(SCRIPT_DIR, './protos')
gen_dir = os.path.join(SCRIPT_DIR, './_gen')
os.makedirs(gen_dir, exist_ok=True)
subprocess.check_call(
'python3 -m grpc_tools.protoc -I{proto_dir} --python_out={gen_dir} --grpc_python_out={gen_dir} \
{proto_dir}/clickhouse_grpc.proto'.format(proto_dir=proto_dir, gen_dir=gen_dir), shell=True)
sys.path.append(gen_dir)
import clickhouse_grpc_pb2
import clickhouse_grpc_pb2_grpc
# Utilities
config_dir = os.path.join(SCRIPT_DIR, './configs')
cluster = ClickHouseCluster(__file__)
node = cluster.add_instance('node', main_configs=['configs/grpc_port.xml'])
server_port = 9001
grpc_port = 9001
main_channel = None
@pytest.fixture(scope="module")
def server_address():
def create_channel():
    """Open an insecure gRPC channel to the node's gRPC port and wait until it is ready.

    The first channel created this way is also remembered in the module-level
    `main_channel`, which the query helpers use when no channel is passed.
    """
    global main_channel
    address = ':'.join((cluster.get_instance_ip('node'), str(grpc_port)))
    new_channel = grpc.insecure_channel(address)
    # Block (up to 10 seconds) until the channel is connected.
    grpc.channel_ready_future(new_channel).result(timeout=10)
    if not main_channel:
        main_channel = new_channel
    return new_channel
def query_common(query_text, settings={}, input_data=[], output_format='TabSeparated', query_id='123', channel=None):
    """Run one query through the ExecuteQuery streaming RPC and return all responses as a list.

    `input_data` is either a single string or a sequence of strings; every string is sent
    in its own QueryInfo message. The first message additionally carries the query text,
    the settings, the output format and the query id. `next_query_info` is set on every
    message except the last one. When `channel` is not given, the module-level
    `main_channel` is used.
    """
    # Work on a private copy: neither the caller's list nor the shared default argument
    # may be modified, otherwise a list passed in would come back emptied.
    if isinstance(input_data, str):
        parts = [input_data]
    else:
        parts = list(input_data)

    stub = clickhouse_grpc_pb2_grpc.ClickHouseStub(channel or main_channel)

    def send_query_info():
        # The first message is always sent, even when there is no input data at all.
        first_part = parts[0] if parts else ''
        yield clickhouse_grpc_pb2.QueryInfo(query=query_text, settings=settings, input_data=first_part,
                                            output_format=output_format, query_id=query_id,
                                            next_query_info=len(parts) > 1)
        # Remaining parts carry only data and the "more follows" flag.
        for index in range(1, len(parts)):
            yield clickhouse_grpc_pb2.QueryInfo(input_data=parts[index], next_query_info=index + 1 < len(parts))

    return list(stub.ExecuteQuery(send_query_info()))
def query_no_errors(*args, **kwargs):
    """Run a query via `query_common` and return its responses.

    Raises `Exception` with the server's message when the last response
    carries an `exception` field.
    """
    responses = query_common(*args, **kwargs)
    last = responses[-1] if responses else None
    if last is not None and last.HasField('exception'):
        raise Exception(last.exception.message)
    return responses
def query(*args, **kwargs):
    """Run a query that must succeed and return the concatenated `output` of all responses."""
    return "".join(response.output for response in query_no_errors(*args, **kwargs))
def query_and_get_error(*args, **kwargs):
    """Run a query that must fail and return the `exception` of the last response.

    Raises `Exception` when there are no responses or the last one has no exception.
    """
    responses = query_common(*args, **kwargs)
    if responses and responses[-1].HasField('exception'):
        return responses[-1].exception
    raise Exception("Expected to be failed but succeeded!")
def query_and_get_totals(*args, **kwargs):
    """Run a query that must succeed and return the concatenated `totals` of all responses."""
    return "".join(response.totals for response in query_no_errors(*args, **kwargs))
def query_and_get_extremes(*args, **kwargs):
    """Run a query that must succeed and return the concatenated `extremes` of all responses."""
    return "".join(response.extremes for response in query_no_errors(*args, **kwargs))
@pytest.fixture(scope="module", autouse=True)
def start_cluster():
cluster.start()
try:
yield cluster.get_instance_ip('node')
with create_channel() as channel:
yield cluster
finally:
cluster.shutdown()
def Query(server_address_and_port, query, mode="output", insert_data=[]):
output = []
totals = []
data_stream = (len(insert_data) != 0)
with grpc.insecure_channel(server_address_and_port) as channel:
grpc.channel_ready_future(channel).result()
stub = clickhouse_grpc_pb2_grpc.GRPCStub(channel)
def write_query():
user_info = clickhouse_grpc_pb2.User(user="default", quota='default')
query_info = clickhouse_grpc_pb2.QuerySettings(query=query, query_id='123', data_stream=data_stream, format='TabSeparated')
yield clickhouse_grpc_pb2.QueryRequest(user_info=user_info, query_info=query_info)
if data_stream:
for data in insert_data:
yield clickhouse_grpc_pb2.QueryRequest(insert_data=data)
yield clickhouse_grpc_pb2.QueryRequest(insert_data="")
for response in stub.Query(write_query(), 10.0):
output += response.output.split()
totals += response.totals.split()
if mode == "output":
return output
elif mode == "totals":
return totals
@pytest.fixture(autouse=True)
def reset_after_test():
    """Autouse fixture: after each test, drop the scratch table `t` if it exists."""
    yield
    # Teardown part: runs once the test body has finished.
    query("DROP TABLE IF EXISTS t")
def test_ordinary_query(server_address):
server_address_and_port = server_address + ':' + str(server_port)
assert Query(server_address_and_port, "SELECT 1") == [u'1']
assert Query(server_address_and_port, "SELECT count() FROM numbers(100)") == [u'100']
# Actual tests
def test_query_insert(server_address):
server_address_and_port = server_address + ':' + str(server_port)
assert Query(server_address_and_port, "CREATE TABLE t (a UInt8) ENGINE = Memory") == []
assert Query(server_address_and_port, "INSERT INTO t VALUES (1),(2),(3)") == []
assert Query(server_address_and_port, "INSERT INTO t FORMAT TabSeparated 4\n5\n6\n") == []
assert Query(server_address_and_port, "INSERT INTO t FORMAT TabSeparated 10\n11\n12\n") == []
assert Query(server_address_and_port, "SELECT a FROM t ORDER BY a") == [u'1', u'2', u'3', u'4', u'5', u'6', u'10', u'11', u'12']
assert Query(server_address_and_port, "DROP TABLE t") == []
def test_select_one():
    """The simplest possible query returns a single row."""
    output = query("SELECT 1")
    assert output == "1\n"
def test_handle_mistakes(server_address):
server_address_and_port = server_address + ':' + str(server_port)
assert Query(server_address_and_port, "") == []
assert Query(server_address_and_port, "CREATE TABLE t (a UInt8) ENGINE = Memory") == []
assert Query(server_address_and_port, "CREATE TABLE t (a UInt8) ENGINE = Memory") == []
def test_ordinary_query():
    """An aggregate over a table function returns the expected count."""
    output = query("SELECT count() FROM numbers(100)")
    assert output == "100\n"
def test_totals(server_address):
server_address_and_port = server_address + ':' + str(server_port)
assert Query(server_address_and_port, "") == []
assert Query(server_address_and_port, "CREATE TABLE tabl (x UInt8, y UInt8) ENGINE = Memory;") == []
assert Query(server_address_and_port, "INSERT INTO tabl VALUES (1, 2), (2, 4), (3, 2), (3, 3), (3, 4);") == []
assert Query(server_address_and_port, "SELECT sum(x), y FROM tabl GROUP BY y WITH TOTALS") == [u'4', u'2', u'3', u'3', u'5', u'4']
assert Query(server_address_and_port, "SELECT sum(x), y FROM tabl GROUP BY y WITH TOTALS", mode="totals") == [u'12', u'0']
def test_insert_query():
    """INSERT works with data inline in the query text and with data passed as `input_data`."""
    query("CREATE TABLE t (a UInt8) ENGINE = Memory")

    # Data embedded in the query text itself.
    query("INSERT INTO t VALUES (1),(2),(3)")
    query("INSERT INTO t FORMAT TabSeparated 4\n5\n6\n")

    # Data sent separately from the query text.
    query("INSERT INTO t VALUES", input_data="(7),(8)")
    query("INSERT INTO t FORMAT TabSeparated", input_data="9\n10\n")

    expected = "1\n2\n3\n4\n5\n6\n7\n8\n9\n10\n"
    assert query("SELECT a FROM t ORDER BY a") == expected
def test_query_insert(server_address):
server_address_and_port = server_address + ':' + str(server_port)
assert Query(server_address_and_port, "CREATE TABLE t (a UInt8) ENGINE = Memory") == []
assert Query(server_address_and_port, "INSERT INTO t VALUES", insert_data=["(1),(2),(3)", "(5),(4),(6)", "(8),(7),(9)"]) == []
assert Query(server_address_and_port, "SELECT a FROM t ORDER BY a") == [u'1', u'2', u'3', u'4', u'5', u'6', u'7', u'8', u'9']
assert Query(server_address_and_port, "DROP TABLE t") == []
def test_insert_query_streaming():
    """INSERT data split into several parts is sent as a stream of messages and fully inserted."""
    query("CREATE TABLE t (a UInt8) ENGINE = Memory")
    data_parts = ["(1),(2),(3)", "(5),(4),(6)", "(8),(7),(9)"]
    query("INSERT INTO t VALUES", input_data=data_parts)
    selected = query("SELECT a FROM t ORDER BY a")
    assert selected == "1\n2\n3\n4\n5\n6\n7\n8\n9\n"
def test_totals_and_extremes():
    """Totals and extremes are returned in their own response fields, separate from the output."""
    query("CREATE TABLE t (x UInt8, y UInt8) ENGINE = Memory")
    query("INSERT INTO t VALUES (1, 2), (2, 4), (3, 2), (3, 3), (3, 4)")

    totals_query = "SELECT sum(x), y FROM t GROUP BY y WITH TOTALS"
    assert query(totals_query) == "4\t2\n3\t3\n5\t4\n"
    assert query_and_get_totals(totals_query) == "12\t0\n"

    plain_query = "SELECT x, y FROM t"
    assert query(plain_query) == "1\t2\n2\t4\n3\t2\n3\t3\n3\t4\n"
    assert query_and_get_extremes(plain_query, settings={"extremes": "1"}) == "1\t2\n3\t4\n"
def test_errors_handling():
    """Failed queries report their error through the `exception` field of the response."""
    empty_query_error = query_and_get_error("")
    assert "Empty query" in empty_query_error.message

    create_statement = "CREATE TABLE t (a UInt8) ENGINE = Memory"
    query(create_statement)
    # Creating the same table a second time must fail.
    duplicate_error = query_and_get_error(create_statement)
    assert "Table default.t already exists" in duplicate_error.message