Added support for sessions in gRPC protocol.

This commit is contained in:
Vitaly Baranov 2020-10-24 19:57:27 +03:00
parent 9285f7edc1
commit b51e14253d
3 changed files with 56 additions and 4 deletions

View File

@ -21,6 +21,7 @@
#include <Processors/Executors/PullingAsyncPipelineExecutor.h>
#include <Server/IServer.h>
#include <Storages/IStorage.h>
#include <Poco/Util/LayeredConfiguration.h>
#include <grpc++/security/server_credentials.h>
#include <grpc++/server.h>
#include <grpc++/server_builder.h>
@ -37,6 +38,7 @@ namespace DB
namespace ErrorCodes
{
extern const int INVALID_GRPC_QUERY_INFO;
extern const int INVALID_SESSION_TIMEOUT;
extern const int NETWORK_ERROR;
extern const int NO_DATA_TO_INSERT;
extern const int UNKNOWN_DATABASE;
@ -44,6 +46,24 @@ namespace ErrorCodes
namespace
{
/// Gets session's timeout from query info or from the server config.
std::chrono::steady_clock::duration getSessionTimeout(const GRPCQueryInfo & query_info, const Poco::Util::AbstractConfiguration & config)
{
auto session_timeout = query_info.session_timeout();
if (session_timeout)
{
auto max_session_timeout = config.getUInt("max_session_timeout", 3600);
if (session_timeout > max_session_timeout)
throw Exception(
"Session timeout '" + std::to_string(session_timeout) + "' is larger than max_session_timeout: "
+ std::to_string(max_session_timeout) + ". Maximum session timeout could be modified in configuration file.",
ErrorCodes::INVALID_SESSION_TIMEOUT);
}
else
session_timeout = config.getInt("default_session_timeout", 60);
return std::chrono::seconds(session_timeout);
}
/// Generates a description of a query by a specified query info.
/// This description is used for logging only.
String getQueryDescription(const GRPCQueryInfo & query_info)
@ -218,6 +238,7 @@ namespace
IServer & iserver;
Poco::Logger * log = nullptr;
std::shared_ptr<NamedSession> session;
std::optional<Context> query_context;
std::optional<CurrentThread::QueryScope> query_scope;
String query_text;
@ -348,6 +369,16 @@ namespace
if (!quota_key.empty())
query_context->setQuotaKey(quota_key);
/// The user could specify session identifier and session timeout.
/// It allows to modify settings, create temporary tables and reuse them in subsequent requests.
if (!query_info.session_id().empty())
{
session = query_context->acquireNamedSession(
query_info.session_id(), getSessionTimeout(query_info, iserver.config()), query_info.session_check());
query_context = session->context;
query_context->setSessionContext(session->context);
}
/// Set client info.
ClientInfo & client_info = query_context->getClientInfo();
client_info.query_kind = ClientInfo::QueryKind::INITIAL_QUERY;
@ -492,7 +523,8 @@ namespace
readQueryInfo();
if (!query_info.query().empty() || !query_info.query_id().empty() || !query_info.settings().empty()
|| !query_info.database().empty() || !query_info.input_data_delimiter().empty() || !query_info.output_format().empty()
|| !query_info.user_name().empty() || !query_info.password().empty() || !query_info.quota().empty())
|| !query_info.user_name().empty() || !query_info.password().empty() || !query_info.quota().empty()
|| !query_info.session_id().empty())
{
throw Exception("Extra query infos can be used only to add more input data. "
"Only the following fields can be set: input_data, next_query_info",
@ -699,6 +731,9 @@ namespace
io = {};
query_scope.reset();
query_context.reset();
if (session)
session->release();
session.reset();
}
void Call::readQueryInfo()

View File

@ -13,7 +13,10 @@ message QueryInfo {
string user_name = 8;
string password = 9;
string quota = 10;
bool next_query_info = 11;
string session_id = 11;
bool session_check = 12;
uint32 session_timeout = 13;
bool next_query_info = 14;
}
enum LogsLevel {

View File

@ -39,7 +39,7 @@ def create_channel():
main_channel = channel
return channel
def query_common(query_text, settings={}, input_data=[], input_data_delimiter='', output_format='TabSeparated', query_id='123', channel=None):
def query_common(query_text, settings={}, input_data=[], input_data_delimiter='', output_format='TabSeparated', query_id='123', session_id='', channel=None):
if type(input_data) == str:
input_data = [input_data]
if not channel:
@ -48,7 +48,7 @@ def query_common(query_text, settings={}, input_data=[], input_data_delimiter=''
def send_query_info():
input_data_part = input_data.pop(0) if input_data else ''
yield clickhouse_grpc_pb2.QueryInfo(query=query_text, settings=settings, input_data=input_data_part, input_data_delimiter=input_data_delimiter,
output_format=output_format, query_id=query_id, next_query_info=bool(input_data))
output_format=output_format, query_id=query_id, session_id=session_id, next_query_info=bool(input_data))
while input_data:
input_data_part = input_data.pop(0)
yield clickhouse_grpc_pb2.QueryInfo(input_data=input_data_part, next_query_info=bool(input_data))
@ -212,3 +212,17 @@ def test_progress():
rows_before_limit: 8
}
]"""
def test_session():
session_a = "session A"
session_b = "session B"
query("SET custom_x=1", session_id=session_a)
query("SET custom_y=2", session_id=session_a)
query("SET custom_x=3", session_id=session_b)
query("SET custom_y=4", session_id=session_b)
assert query("SELECT getSetting('custom_x'), getSetting('custom_y')", session_id=session_a) == "1\t2\n"
assert query("SELECT getSetting('custom_x'), getSetting('custom_y')", session_id=session_b) == "3\t4\n"
def test_no_session():
e = query_and_get_error("SET custom_x=1")
assert "There is no session" in e.display_text