Code cleanups and improvements.

This commit is contained in:
Vitaly Baranov 2021-08-01 17:12:34 +03:00
parent 51ffc33457
commit fabd7193bd
27 changed files with 677 additions and 682 deletions

View File

@ -9,7 +9,6 @@
#include <Databases/DatabaseMemory.h> #include <Databases/DatabaseMemory.h>
#include <Storages/System/attachSystemTables.h> #include <Storages/System/attachSystemTables.h>
#include <Interpreters/ProcessList.h> #include <Interpreters/ProcessList.h>
#include <Interpreters/Session.h>
#include <Interpreters/executeQuery.h> #include <Interpreters/executeQuery.h>
#include <Interpreters/loadMetadata.h> #include <Interpreters/loadMetadata.h>
#include <Interpreters/DatabaseCatalog.h> #include <Interpreters/DatabaseCatalog.h>
@ -377,11 +376,13 @@ void LocalServer::processQueries()
/// we can't mutate global global_context (can lead to races, as it was already passed to some background threads) /// we can't mutate global global_context (can lead to races, as it was already passed to some background threads)
/// so we can't reuse it safely as a query context and need a copy here /// so we can't reuse it safely as a query context and need a copy here
Session session(global_context, ClientInfo::Interface::TCP); auto context = Context::createCopy(global_context);
session.setUser("default", "", Poco::Net::SocketAddress{});
auto context = session.makeQueryContext(""); context->makeSessionContext();
context->makeQueryContext();
context->authenticate("default", "", Poco::Net::SocketAddress{});
context->setCurrentQueryId("");
applyCmdSettings(context); applyCmdSettings(context);
/// Use the same query_id (and thread group) for all queries /// Use the same query_id (and thread group) for all queries

View File

@ -47,13 +47,13 @@
#include <Interpreters/ExternalDictionariesLoader.h> #include <Interpreters/ExternalDictionariesLoader.h>
#include <Interpreters/ExternalModelsLoader.h> #include <Interpreters/ExternalModelsLoader.h>
#include <Interpreters/ProcessList.h> #include <Interpreters/ProcessList.h>
#include <Interpreters/Session.h>
#include <Interpreters/loadMetadata.h> #include <Interpreters/loadMetadata.h>
#include <Interpreters/DatabaseCatalog.h> #include <Interpreters/DatabaseCatalog.h>
#include <Interpreters/DNSCacheUpdater.h> #include <Interpreters/DNSCacheUpdater.h>
#include <Interpreters/ExternalLoaderXMLConfigRepository.h> #include <Interpreters/ExternalLoaderXMLConfigRepository.h>
#include <Interpreters/InterserverCredentials.h> #include <Interpreters/InterserverCredentials.h>
#include <Interpreters/JIT/CompiledExpressionCache.h> #include <Interpreters/JIT/CompiledExpressionCache.h>
#include <Interpreters/Session.h>
#include <Access/AccessControlManager.h> #include <Access/AccessControlManager.h>
#include <Storages/StorageReplicatedMergeTree.h> #include <Storages/StorageReplicatedMergeTree.h>
#include <Storages/System/attachSystemTables.h> #include <Storages/System/attachSystemTables.h>
@ -1429,7 +1429,7 @@ if (ThreadFuzzer::instance().isEffective())
/// Must be done after initialization of `servers`, because async_metrics will access `servers` variable from its thread. /// Must be done after initialization of `servers`, because async_metrics will access `servers` variable from its thread.
async_metrics.start(); async_metrics.start();
Session::enableNamedSessions(); Session::startupNamedSessions();
{ {
String level_str = config().getString("text_log.level", ""); String level_str = config().getString("text_log.level", "");

View File

@ -70,6 +70,7 @@ public:
/// Returns the current user. The function can return nullptr. /// Returns the current user. The function can return nullptr.
UserPtr getUser() const; UserPtr getUser() const;
String getUserName() const; String getUserName() const;
std::optional<UUID> getUserID() const { return getParams().user_id; }
/// Returns information about current and enabled roles. /// Returns information about current and enabled roles.
std::shared_ptr<const EnabledRolesInfo> getRolesInfo() const; std::shared_ptr<const EnabledRolesInfo> getRolesInfo() const;

View File

@ -26,6 +26,8 @@ protected:
String user_name; String user_name;
}; };
/// Does not check the password/credentials and that the specified host is allowed.
/// (Used only internally in cluster, if the secret matches)
class AlwaysAllowCredentials class AlwaysAllowCredentials
: public Credentials : public Credentials
{ {

View File

@ -5,6 +5,7 @@
#include <Poco/Net/HTTPRequest.h> #include <Poco/Net/HTTPRequest.h>
#include <Poco/URI.h> #include <Poco/URI.h>
#include <filesystem> #include <filesystem>
#include <thread>
namespace fs = std::filesystem; namespace fs = std::filesystem;

View File

@ -2,8 +2,6 @@
#include <Core/MySQL/PacketsConnection.h> #include <Core/MySQL/PacketsConnection.h>
#include <Poco/RandomStream.h> #include <Poco/RandomStream.h>
#include <Poco/SHA1Engine.h> #include <Poco/SHA1Engine.h>
#include <Access/User.h>
#include <Access/AccessControlManager.h>
#include <Interpreters/Session.h> #include <Interpreters/Session.h>
#include <common/logger_useful.h> #include <common/logger_useful.h>
@ -74,7 +72,7 @@ Native41::Native41(const String & password, const String & auth_plugin_data)
} }
void Native41::authenticate( void Native41::authenticate(
const String & user_name, std::optional<String> auth_response, Session & session, const String & user_name, Session & session, std::optional<String> auth_response,
std::shared_ptr<PacketEndpoint> packet_endpoint, bool, const Poco::Net::SocketAddress & address) std::shared_ptr<PacketEndpoint> packet_endpoint, bool, const Poco::Net::SocketAddress & address)
{ {
if (!auth_response) if (!auth_response)
@ -87,7 +85,7 @@ void Native41::authenticate(
if (auth_response->empty()) if (auth_response->empty())
{ {
session.setUser(user_name, "", address); session.authenticate(user_name, "", address);
return; return;
} }
@ -97,9 +95,7 @@ void Native41::authenticate(
+ " bytes, received: " + std::to_string(auth_response->size()) + " bytes.", + " bytes, received: " + std::to_string(auth_response->size()) + " bytes.",
ErrorCodes::UNKNOWN_EXCEPTION); ErrorCodes::UNKNOWN_EXCEPTION);
const auto user_authentication = session.getUserAuthentication(user_name); Poco::SHA1Engine::Digest double_sha1_value = session.getPasswordDoubleSHA1(user_name);
Poco::SHA1Engine::Digest double_sha1_value = user_authentication.getPasswordDoubleSHA1();
assert(double_sha1_value.size() == Poco::SHA1Engine::DIGEST_SIZE); assert(double_sha1_value.size() == Poco::SHA1Engine::DIGEST_SIZE);
Poco::SHA1Engine engine; Poco::SHA1Engine engine;
@ -112,7 +108,7 @@ void Native41::authenticate(
{ {
password_sha1[i] = digest[i] ^ static_cast<unsigned char>((*auth_response)[i]); password_sha1[i] = digest[i] ^ static_cast<unsigned char>((*auth_response)[i]);
} }
session.setUser(user_name, password_sha1, address); session.authenticate(user_name, password_sha1, address);
} }
#if USE_SSL #if USE_SSL
@ -137,7 +133,7 @@ Sha256Password::Sha256Password(RSA & public_key_, RSA & private_key_, Poco::Logg
} }
void Sha256Password::authenticate( void Sha256Password::authenticate(
const String & user_name, std::optional<String> auth_response, Session & session, const String & user_name, Session & session, std::optional<String> auth_response,
std::shared_ptr<PacketEndpoint> packet_endpoint, bool is_secure_connection, const Poco::Net::SocketAddress & address) std::shared_ptr<PacketEndpoint> packet_endpoint, bool is_secure_connection, const Poco::Net::SocketAddress & address)
{ {
if (!auth_response) if (!auth_response)
@ -232,7 +228,7 @@ void Sha256Password::authenticate(
password.pop_back(); password.pop_back();
} }
session.setUser(user_name, password, address); session.authenticate(user_name, password, address);
} }
#endif #endif

View File

@ -15,6 +15,7 @@
namespace DB namespace DB
{ {
class Session;
namespace MySQLProtocol namespace MySQLProtocol
{ {
@ -32,7 +33,7 @@ public:
virtual String getAuthPluginData() = 0; virtual String getAuthPluginData() = 0;
virtual void authenticate( virtual void authenticate(
const String & user_name, std::optional<String> auth_response, Session & session, const String & user_name, Session & session, std::optional<String> auth_response,
std::shared_ptr<PacketEndpoint> packet_endpoint, bool is_secure_connection, const Poco::Net::SocketAddress & address) = 0; std::shared_ptr<PacketEndpoint> packet_endpoint, bool is_secure_connection, const Poco::Net::SocketAddress & address) = 0;
}; };
@ -49,7 +50,7 @@ public:
String getAuthPluginData() override { return scramble; } String getAuthPluginData() override { return scramble; }
void authenticate( void authenticate(
const String & user_name, std::optional<String> auth_response, Session & session, const String & user_name, Session & session, std::optional<String> auth_response,
std::shared_ptr<PacketEndpoint> packet_endpoint, bool /* is_secure_connection */, const Poco::Net::SocketAddress & address) override; std::shared_ptr<PacketEndpoint> packet_endpoint, bool /* is_secure_connection */, const Poco::Net::SocketAddress & address) override;
private: private:
@ -69,7 +70,7 @@ public:
String getAuthPluginData() override { return scramble; } String getAuthPluginData() override { return scramble; }
void authenticate( void authenticate(
const String & user_name, std::optional<String> auth_response, Session & session, const String & user_name, Session & session, std::optional<String> auth_response,
std::shared_ptr<PacketEndpoint> packet_endpoint, bool is_secure_connection, const Poco::Net::SocketAddress & address) override; std::shared_ptr<PacketEndpoint> packet_endpoint, bool is_secure_connection, const Poco::Net::SocketAddress & address) override;
private: private:

View File

@ -1,14 +1,11 @@
#pragma once #pragma once
#include <Access/AccessControlManager.h>
#include <Access/User.h>
#include <functional> #include <functional>
#include <Interpreters/Session.h>
#include <Interpreters/Context.h>
#include <IO/ReadBuffer.h> #include <IO/ReadBuffer.h>
#include <IO/ReadHelpers.h> #include <IO/ReadHelpers.h>
#include <IO/WriteBuffer.h> #include <IO/WriteBuffer.h>
#include <IO/WriteHelpers.h> #include <IO/WriteHelpers.h>
#include <Interpreters/Session.h>
#include <common/logger_useful.h> #include <common/logger_useful.h>
#include <Poco/Format.h> #include <Poco/Format.h>
#include <Poco/RegularExpression.h> #include <Poco/RegularExpression.h>
@ -808,8 +805,9 @@ protected:
Messaging::MessageTransport & mt, Messaging::MessageTransport & mt,
const Poco::Net::SocketAddress & address) const Poco::Net::SocketAddress & address)
{ {
try { try
session.setUser(user_name, password, address); {
session.authenticate(user_name, password, address);
} }
catch (const Exception &) catch (const Exception &)
{ {
@ -841,7 +839,7 @@ public:
Messaging::MessageTransport & mt, Messaging::MessageTransport & mt,
const Poco::Net::SocketAddress & address) override const Poco::Net::SocketAddress & address) override
{ {
setPassword(user_name, "", session, mt, address); return setPassword(user_name, "", session, mt, address);
} }
Authentication::Type getType() const override Authentication::Type getType() const override
@ -865,7 +863,7 @@ public:
if (type == Messaging::FrontMessageType::PASSWORD_MESSAGE) if (type == Messaging::FrontMessageType::PASSWORD_MESSAGE)
{ {
std::unique_ptr<Messaging::PasswordMessage> password = mt.receive<Messaging::PasswordMessage>(); std::unique_ptr<Messaging::PasswordMessage> password = mt.receive<Messaging::PasswordMessage>();
setPassword(user_name, password->password, session, mt, address); return setPassword(user_name, password->password, session, mt, address);
} }
else else
throw Exception( throw Exception(
@ -902,16 +900,7 @@ public:
Messaging::MessageTransport & mt, Messaging::MessageTransport & mt,
const Poco::Net::SocketAddress & address) const Poco::Net::SocketAddress & address)
{ {
Authentication::Type user_auth_type; Authentication::Type user_auth_type = session.getAuthenticationType(user_name);
try
{
user_auth_type = session.getUserAuthentication(user_name).getType();
}
catch (const std::exception & e)
{
session.onLogInFailure(user_name, e);
throw;
}
if (type_to_method.find(user_auth_type) != type_to_method.end()) if (type_to_method.find(user_auth_type) != type_to_method.end())
{ {

View File

@ -255,7 +255,7 @@ void registerDictionarySourceClickHouse(DictionarySourceFactory & factory)
/// We should set user info even for the case when the dictionary is loaded in-process (without TCP communication). /// We should set user info even for the case when the dictionary is loaded in-process (without TCP communication).
if (configuration.is_local) if (configuration.is_local)
{ {
context_copy->setUser(configuration.user, configuration.password, Poco::Net::SocketAddress("127.0.0.1", 0)); context_copy->authenticate(configuration.user, configuration.password, Poco::Net::SocketAddress("127.0.0.1", 0));
context_copy = copyContextAndApplySettings(config_prefix, context_copy, config); context_copy = copyContextAndApplySettings(config_prefix, context_copy, config);
} }

View File

@ -12,6 +12,7 @@
#include <Common/UnicodeBar.h> #include <Common/UnicodeBar.h>
#include <Common/TerminalSize.h> #include <Common/TerminalSize.h>
#include <IO/Operators.h> #include <IO/Operators.h>
#include <IO/Progress.h>
namespace ProfileEvents namespace ProfileEvents

View File

@ -588,48 +588,45 @@ ConfigurationPtr Context::getUsersConfig()
} }
void Context::setUser(const Credentials & credentials, const Poco::Net::SocketAddress & address) void Context::authenticate(const String & name, const String & password, const Poco::Net::SocketAddress & address)
{ {
auto lock = getLock(); authenticate(BasicCredentials(name, password), address);
}
void Context::authenticate(const Credentials & credentials, const Poco::Net::SocketAddress & address)
{
auto authenticated_user_id = getAccessControlManager().login(credentials, address.host());
client_info.current_user = credentials.getUserName(); client_info.current_user = credentials.getUserName();
client_info.current_address = address; client_info.current_address = address;
#if defined(ARCADIA_BUILD) #if defined(ARCADIA_BUILD)
/// This is harmful field that is used only in foreign "Arcadia" build. /// This is harmful field that is used only in foreign "Arcadia" build.
client_info.current_password.clear();
if (const auto * basic_credentials = dynamic_cast<const BasicCredentials *>(&credentials)) if (const auto * basic_credentials = dynamic_cast<const BasicCredentials *>(&credentials))
client_info.current_password = basic_credentials->getPassword(); client_info.current_password = basic_credentials->getPassword();
#endif #endif
/// Find a user with such name and check the credentials. setUser(authenticated_user_id);
auto new_user_id = getAccessControlManager().login(credentials, address.host()); }
auto new_access = getAccessControlManager().getContextAccess(
new_user_id, /* current_roles = */ {}, /* use_default_roles = */ true,
settings, current_database, client_info);
user_id = new_user_id; void Context::setUser(const UUID & user_id_)
access = std::move(new_access); {
auto lock = getLock();
user_id = user_id_;
access = getAccessControlManager().getContextAccess(
user_id_, /* current_roles = */ {}, /* use_default_roles = */ true, settings, current_database, client_info);
auto user = access->getUser(); auto user = access->getUser();
current_roles = std::make_shared<std::vector<UUID>>(user->granted_roles.findGranted(user->default_roles)); current_roles = std::make_shared<std::vector<UUID>>(user->granted_roles.findGranted(user->default_roles));
if (!user->default_database.empty())
setCurrentDatabase(user->default_database);
auto default_profile_info = access->getDefaultProfileInfo(); auto default_profile_info = access->getDefaultProfileInfo();
settings_constraints_and_current_profiles = default_profile_info->getConstraintsAndProfileIDs(); settings_constraints_and_current_profiles = default_profile_info->getConstraintsAndProfileIDs();
applySettingsChanges(default_profile_info->settings); applySettingsChanges(default_profile_info->settings);
}
void Context::setUser(const String & name, const String & password, const Poco::Net::SocketAddress & address) if (!user->default_database.empty())
{ setCurrentDatabase(user->default_database);
setUser(BasicCredentials(name, password), address);
}
void Context::setUserWithoutCheckingPassword(const String & name, const Poco::Net::SocketAddress & address)
{
setUser(AlwaysAllowCredentials(name), address);
} }
std::shared_ptr<const User> Context::getUser() const std::shared_ptr<const User> Context::getUser() const
@ -637,12 +634,6 @@ std::shared_ptr<const User> Context::getUser() const
return getAccess()->getUser(); return getAccess()->getUser();
} }
void Context::setQuotaKey(String quota_key_)
{
auto lock = getLock();
client_info.quota_key = std::move(quota_key_);
}
String Context::getUserName() const String Context::getUserName() const
{ {
return getAccess()->getUserName(); return getAccess()->getUserName();
@ -655,6 +646,13 @@ std::optional<UUID> Context::getUserID() const
} }
void Context::setQuotaKey(String quota_key_)
{
auto lock = getLock();
client_info.quota_key = std::move(quota_key_);
}
void Context::setCurrentRoles(const std::vector<UUID> & current_roles_) void Context::setCurrentRoles(const std::vector<UUID> & current_roles_)
{ {
auto lock = getLock(); auto lock = getLock();
@ -736,10 +734,13 @@ ASTPtr Context::getRowPolicyCondition(const String & database, const String & ta
void Context::setInitialRowPolicy() void Context::setInitialRowPolicy()
{ {
auto lock = getLock(); auto lock = getLock();
auto initial_user_id = getAccessControlManager().find<User>(client_info.initial_user);
initial_row_policy = nullptr; initial_row_policy = nullptr;
if (initial_user_id) if (client_info.initial_user == client_info.current_user)
initial_row_policy = getAccessControlManager().getEnabledRowPolicies(*initial_user_id, {}); return;
auto initial_user_id = getAccessControlManager().find<User>(client_info.initial_user);
if (!initial_user_id)
return;
initial_row_policy = getAccessControlManager().getEnabledRowPolicies(*initial_user_id, {});
} }
@ -1180,6 +1181,9 @@ void Context::setCurrentQueryId(const String & query_id)
} }
client_info.current_query_id = query_id_to_set; client_info.current_query_id = query_id_to_set;
if (client_info.query_kind == ClientInfo::QueryKind::INITIAL_QUERY)
client_info.initial_query_id = client_info.current_query_id;
} }
void Context::killCurrentQuery() void Context::killCurrentQuery()

View File

@ -14,21 +14,16 @@
#include <Common/MultiVersion.h> #include <Common/MultiVersion.h>
#include <Common/OpenTelemetryTraceContext.h> #include <Common/OpenTelemetryTraceContext.h>
#include <Common/RemoteHostFilter.h> #include <Common/RemoteHostFilter.h>
#include <Common/ThreadPool.h>
#include <common/types.h> #include <common/types.h>
#if !defined(ARCADIA_BUILD) #if !defined(ARCADIA_BUILD)
# include "config_core.h" # include "config_core.h"
#endif #endif
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <mutex> #include <mutex>
#include <optional> #include <optional>
#include <thread>
namespace Poco::Net { class IPAddress; } namespace Poco::Net { class IPAddress; }
@ -67,6 +62,7 @@ class ProcessList;
class QueryStatus; class QueryStatus;
class Macros; class Macros;
struct Progress; struct Progress;
struct FileProgress;
class Clusters; class Clusters;
class QueryLog; class QueryLog;
class QueryThreadLog; class QueryThreadLog;
@ -366,23 +362,21 @@ public:
void setUsersConfig(const ConfigurationPtr & config); void setUsersConfig(const ConfigurationPtr & config);
ConfigurationPtr getUsersConfig(); ConfigurationPtr getUsersConfig();
/// Sets the current user, checks the credentials and that the specified host is allowed. /// Sets the current user, checks the credentials and that the specified address is allowed to connect from.
/// Must be called before getClientInfo() can be called. /// The function throws an exception if there is no such user or password is wrong.
void setUser(const Credentials & credentials, const Poco::Net::SocketAddress & address); void authenticate(const String & user_name, const String & password, const Poco::Net::SocketAddress & address);
void setUser(const String & name, const String & password, const Poco::Net::SocketAddress & address); void authenticate(const Credentials & credentials, const Poco::Net::SocketAddress & address);
/// Sets the current user, *does not check the password/credentials and that the specified host is allowed*. /// Sets the current user assuming that he/she is already authenticated.
/// Must be called before getClientInfo. /// WARNING: This function doesn't check password! Don't use until it's necessary!
/// void setUser(const UUID & user_id_);
/// (Used only internally in cluster, if the secret matches)
void setUserWithoutCheckingPassword(const String & name, const Poco::Net::SocketAddress & address);
void setQuotaKey(String quota_key_);
UserPtr getUser() const; UserPtr getUser() const;
String getUserName() const; String getUserName() const;
std::optional<UUID> getUserID() const; std::optional<UUID> getUserID() const;
void setQuotaKey(String quota_key_);
void setCurrentRoles(const std::vector<UUID> & current_roles_); void setCurrentRoles(const std::vector<UUID> & current_roles_);
void setCurrentRolesDefault(); void setCurrentRolesDefault();
boost::container::flat_set<UUID> getCurrentRoles() const; boost::container::flat_set<UUID> getCurrentRoles() const;
@ -590,8 +584,6 @@ public:
std::optional<UInt16> getTCPPortSecure() const; std::optional<UInt16> getTCPPortSecure() const;
std::shared_ptr<NamedSession> acquireNamedSession(const String & session_id, std::chrono::steady_clock::duration timeout, bool session_check);
/// For methods below you may need to acquire the context lock by yourself. /// For methods below you may need to acquire the context lock by yourself.
ContextMutablePtr getQueryContext() const; ContextMutablePtr getQueryContext() const;
@ -602,7 +594,6 @@ public:
bool hasSessionContext() const { return !session_context.expired(); } bool hasSessionContext() const { return !session_context.expired(); }
ContextMutablePtr getGlobalContext() const; ContextMutablePtr getGlobalContext() const;
bool hasGlobalContext() const { return !global_context.expired(); } bool hasGlobalContext() const { return !global_context.expired(); }
bool isGlobalContext() const bool isGlobalContext() const
{ {

View File

@ -1,24 +1,22 @@
#include <Interpreters/Session.h> #include <Interpreters/Session.h>
#include <Access/AccessControlManager.h> #include <Access/AccessControlManager.h>
#include <Access/ContextAccess.h>
#include <Access/Credentials.h> #include <Access/Credentials.h>
#include <Access/ContextAccess.h>
#include <Access/User.h> #include <Access/User.h>
#include <Common/Exception.h> #include <Common/Exception.h>
#include <Common/ThreadPool.h> #include <Common/ThreadPool.h>
#include <Common/setThreadName.h> #include <Common/setThreadName.h>
#include <Interpreters/Context.h> #include <Interpreters/Context.h>
#include <cassert>
#include <chrono>
#include <mutex>
#include <condition_variable>
#include <atomic> #include <atomic>
#include <unordered_map> #include <condition_variable>
#include <deque> #include <deque>
#include <mutex>
#include <unordered_map>
#include <vector> #include <vector>
namespace DB namespace DB
{ {
@ -27,13 +25,13 @@ namespace ErrorCodes
extern const int LOGICAL_ERROR; extern const int LOGICAL_ERROR;
extern const int SESSION_NOT_FOUND; extern const int SESSION_NOT_FOUND;
extern const int SESSION_IS_LOCKED; extern const int SESSION_IS_LOCKED;
extern const int NOT_IMPLEMENTED;
} }
class NamedSessionsStorage; class NamedSessionsStorage;
/// User name and session identifier. Named sessions are local to users. /// User ID and session identifier. Named sessions are local to users.
using NamedSessionKey = std::pair<String, String>; using NamedSessionKey = std::pair<UUID, String>;
/// Named sessions. The user could specify session identifier to reuse settings and temporary tables in subsequent requests. /// Named sessions. The user could specify session identifier to reuse settings and temporary tables in subsequent requests.
struct NamedSessionData struct NamedSessionData
@ -75,21 +73,16 @@ public:
} }
/// Find existing session or create a new. /// Find existing session or create a new.
std::shared_ptr<NamedSessionData> acquireSession( std::pair<std::shared_ptr<NamedSessionData>, bool> acquireSession(
const ContextPtr & global_context,
const UUID & user_id,
const String & session_id, const String & session_id,
ContextMutablePtr context,
std::chrono::steady_clock::duration timeout, std::chrono::steady_clock::duration timeout,
bool throw_if_not_found) bool throw_if_not_found)
{ {
std::unique_lock lock(mutex); std::unique_lock lock(mutex);
const auto & client_info = context->getClientInfo(); Key key{user_id, session_id};
const auto & user_name = client_info.current_user;
if (user_name.empty())
throw Exception("Empty user name.", ErrorCodes::LOGICAL_ERROR);
Key key(user_name, session_id);
auto it = sessions.find(key); auto it = sessions.find(key);
if (it == sessions.end()) if (it == sessions.end())
@ -98,22 +91,20 @@ public:
throw Exception("Session not found.", ErrorCodes::SESSION_NOT_FOUND); throw Exception("Session not found.", ErrorCodes::SESSION_NOT_FOUND);
/// Create a new session from current context. /// Create a new session from current context.
auto context = Context::createCopy(global_context);
it = sessions.insert(std::make_pair(key, std::make_shared<NamedSessionData>(key, context, timeout, *this))).first; it = sessions.insert(std::make_pair(key, std::make_shared<NamedSessionData>(key, context, timeout, *this))).first;
const auto & session = it->second;
return {session, true};
} }
else if (it->second->key.first != client_info.current_user) else
{ {
throw Exception("Session belongs to a different user", ErrorCodes::SESSION_IS_LOCKED); /// Use existing session.
const auto & session = it->second;
if (!session.unique())
throw Exception("Session is locked by a concurrent client.", ErrorCodes::SESSION_IS_LOCKED);
return {session, false};
} }
/// Use existing session.
const auto & session = it->second;
if (!session.unique())
throw Exception("Session is locked by a concurrent client.", ErrorCodes::SESSION_IS_LOCKED);
session->context->getClientInfo() = client_info;
return session;
} }
void releaseSession(NamedSessionData & session) void releaseSession(NamedSessionData & session)
@ -229,164 +220,195 @@ void NamedSessionData::release()
std::optional<NamedSessionsStorage> Session::named_sessions = std::nullopt; std::optional<NamedSessionsStorage> Session::named_sessions = std::nullopt;
void Session::enableNamedSessions() void Session::startupNamedSessions()
{ {
named_sessions.emplace(); named_sessions.emplace();
} }
Session::Session(const ContextPtr & context_to_copy, ClientInfo::Interface interface, std::optional<String> default_format) Session::Session(const ContextPtr & global_context_, ClientInfo::Interface interface_)
: session_context(Context::createCopy(context_to_copy)), : global_context(global_context_)
initial_session_context(session_context)
{ {
session_context->makeSessionContext(); prepared_client_info.emplace();
session_context->getClientInfo().interface = interface; prepared_client_info->interface = interface_;
if (default_format)
session_context->setDefaultFormat(*default_format);
} }
Session::Session(Session &&) = default; Session::Session(Session &&) = default;
Session::~Session() Session::~Session()
{ {
releaseNamedSession(); /// Early release a NamedSessionData.
if (access)
{
auto user = access->getUser();
if (user)
onLogOut();
}
}
Authentication Session::getUserAuthentication(const String & user_name) const
{
return session_context->getAccessControlManager().read<User>(user_name)->authentication;
}
void Session::setUser(const Credentials & credentials, const Poco::Net::SocketAddress & address)
{
try
{
session_context->setUser(credentials, address);
// Caching access just in case if context is going to be replaced later (e.g. with context of NamedSessionData)
access = session_context->getAccess();
// Check if this is a not an intercluster session, but the real one.
if (access && access->getUser() && dynamic_cast<const BasicCredentials *>(&credentials))
{
onLogInSuccess();
}
}
catch (const std::exception & e)
{
onLogInFailure(credentials.getUserName(), e);
throw;
}
}
void Session::setUser(const String & name, const String & password, const Poco::Net::SocketAddress & address)
{
setUser(BasicCredentials(name, password), address);
}
void Session::onLogInSuccess()
{
}
void Session::onLogInFailure(const String & /* user_name */, const std::exception & /* failure_reason */)
{
}
void Session::onLogOut()
{
}
void Session::promoteToNamedSession(const String & session_id, std::chrono::steady_clock::duration timeout, bool session_check)
{
if (!named_sessions)
throw Exception("Support for named sessions is not enabled", ErrorCodes::NOT_IMPLEMENTED);
auto new_named_session = named_sessions->acquireSession(session_id, session_context, timeout, session_check);
// Must retain previous client info cause otherwise source client address and port,
// and other stuff are reused from previous user of the said session.
const ClientInfo prev_client_info = session_context->getClientInfo();
session_context = new_named_session->context;
session_context->getClientInfo() = prev_client_info;
session_context->makeSessionContext();
named_session.swap(new_named_session);
}
/// Early release a NamedSessionData.
void Session::releaseNamedSession()
{
if (named_session) if (named_session)
{
named_session->release(); named_session->release();
named_session.reset();
}
session_context = initial_session_context;
} }
ContextMutablePtr Session::makeQueryContext(const String & query_id) const Authentication::Type Session::getAuthenticationType(const String & user_name) const
{ {
ContextMutablePtr new_query_context = Context::createCopy(session_context); return global_context->getAccessControlManager().read<User>(user_name)->authentication.getType();
new_query_context->setCurrentQueryId(query_id);
new_query_context->setSessionContext(session_context);
new_query_context->makeQueryContext();
ClientInfo & client_info = new_query_context->getClientInfo();
client_info.initial_user = client_info.current_user;
client_info.initial_query_id = client_info.current_query_id;
client_info.initial_address = client_info.current_address;
return new_query_context;
} }
ContextPtr Session::sessionContext() const Authentication::Digest Session::getPasswordDoubleSHA1(const String & user_name) const
{ {
return session_context; return global_context->getAccessControlManager().read<User>(user_name)->authentication.getPasswordDoubleSHA1();
} }
ContextMutablePtr Session::mutableSessionContext() void Session::authenticate(const String & user_name, const String & password, const Poco::Net::SocketAddress & address)
{ {
return session_context; authenticate(BasicCredentials{user_name, password}, address);
}
void Session::authenticate(const Credentials & credentials_, const Poco::Net::SocketAddress & address_)
{
if (session_context)
throw Exception("If there is a session context it must be created after authentication", ErrorCodes::LOGICAL_ERROR);
user_id = global_context->getAccessControlManager().login(credentials_, address_.host());
prepared_client_info->current_user = credentials_.getUserName();
prepared_client_info->current_address = address_;
#if defined(ARCADIA_BUILD)
/// This is harmful field that is used only in foreign "Arcadia" build.
if (const auto * basic_credentials = dynamic_cast<const BasicCredentials *>(&credentials_))
session_client_info->current_password = basic_credentials->getPassword();
#endif
} }
ClientInfo & Session::getClientInfo() ClientInfo & Session::getClientInfo()
{ {
return session_context->getClientInfo(); return session_context ? session_context->getClientInfo() : *prepared_client_info;
} }
const ClientInfo & Session::getClientInfo() const const ClientInfo & Session::getClientInfo() const
{ {
return session_context->getClientInfo(); return session_context ? session_context->getClientInfo() : *prepared_client_info;
} }
const Settings & Session::getSettings() const ContextMutablePtr Session::makeSessionContext()
{ {
return session_context->getSettingsRef(); if (session_context)
throw Exception("Session context already exists", ErrorCodes::LOGICAL_ERROR);
if (query_context_created)
throw Exception("Session context must be created before any query context", ErrorCodes::LOGICAL_ERROR);
/// Make a new session context.
ContextMutablePtr new_session_context;
new_session_context = Context::createCopy(global_context);
new_session_context->makeSessionContext();
/// Copy prepared client info to the new session context.
auto & res_client_info = new_session_context->getClientInfo();
res_client_info = std::move(prepared_client_info).value();
prepared_client_info.reset();
/// Set user information for the new context: current profiles, roles, access rights.
if (user_id)
new_session_context->setUser(*user_id);
/// Session context is ready.
session_context = new_session_context;
user = session_context->getUser();
return session_context;
} }
void Session::setQuotaKey(const String & quota_key) ContextMutablePtr Session::makeSessionContext(const String & session_id_, std::chrono::steady_clock::duration timeout_, bool session_check_)
{ {
session_context->setQuotaKey(quota_key); if (session_context)
throw Exception("Session context already exists", ErrorCodes::LOGICAL_ERROR);
if (query_context_created)
throw Exception("Session context must be created before any query context", ErrorCodes::LOGICAL_ERROR);
if (!named_sessions)
throw Exception("Support for named sessions is not enabled", ErrorCodes::LOGICAL_ERROR);
/// Make a new session context OR
/// if the `session_id` and `user_id` were used before then just get a previously created session context.
std::shared_ptr<NamedSessionData> new_named_session;
bool new_named_session_created = false;
std::tie(new_named_session, new_named_session_created)
= named_sessions->acquireSession(global_context, user_id.value_or(UUID{}), session_id_, timeout_, session_check_);
auto new_session_context = new_named_session->context;
new_session_context->makeSessionContext();
/// Copy prepared client info to the session context, no matter it's been just created or not.
/// If we continue using a previously created session context found by session ID
/// it's necessary to replace the client info in it anyway, because it contains actual connection information (client address, etc.)
auto & res_client_info = new_session_context->getClientInfo();
res_client_info = std::move(prepared_client_info).value();
prepared_client_info.reset();
/// Set user information for the new context: current profiles, roles, access rights.
if (user_id && !new_session_context->getUser())
new_session_context->setUser(*user_id);
/// Session context is ready.
session_context = new_session_context;
session_id = session_id_;
named_session = new_named_session;
named_session_created = new_named_session_created;
user = session_context->getUser();
return session_context;
} }
String Session::getCurrentDatabase() const ContextMutablePtr Session::makeQueryContext(const ClientInfo & query_client_info) const
{ {
return session_context->getCurrentDatabase(); return makeQueryContextImpl(&query_client_info, nullptr);
} }
void Session::setCurrentDatabase(const String & name) ContextMutablePtr Session::makeQueryContext(ClientInfo && query_client_info) const
{ {
session_context->setCurrentDatabase(name); return makeQueryContextImpl(nullptr, &query_client_info);
}
ContextMutablePtr Session::makeQueryContextImpl(const ClientInfo * client_info_to_copy, ClientInfo * client_info_to_move) const
{
/// We can create a query context either from a session context or from a global context.
bool from_session_context = static_cast<bool>(session_context);
/// Create a new query context.
ContextMutablePtr query_context = Context::createCopy(from_session_context ? session_context : global_context);
query_context->makeQueryContext();
/// Copy the specified client info to the new query context.
auto & res_client_info = query_context->getClientInfo();
if (client_info_to_move)
res_client_info = std::move(*client_info_to_move);
else if (client_info_to_copy && (client_info_to_copy != &getClientInfo()))
res_client_info = *client_info_to_copy;
/// Copy current user's name and address if it was authenticated after query_client_info was initialized.
if (prepared_client_info && !prepared_client_info->current_user.empty())
{
res_client_info.current_user = prepared_client_info->current_user;
res_client_info.current_address = prepared_client_info->current_address;
#if defined(ARCADIA_BUILD)
res_client_info.current_password = prepared_client_info->current_password;
#endif
}
/// Set parameters of initial query.
if (res_client_info.query_kind == ClientInfo::QueryKind::NO_QUERY)
res_client_info.query_kind = ClientInfo::QueryKind::INITIAL_QUERY;
if (res_client_info.query_kind == ClientInfo::QueryKind::INITIAL_QUERY)
{
res_client_info.initial_user = res_client_info.current_user;
res_client_info.initial_address = res_client_info.current_address;
}
/// Sets that row policies from the initial user should be used too.
query_context->setInitialRowPolicy();
/// Set user information for the new context: current profiles, roles, access rights.
if (user_id && !query_context->getUser())
query_context->setUser(*user_id);
/// Query context is ready.
query_context_created = true;
user = query_context->getUser();
return query_context;
} }
} }

View File

@ -1,8 +1,9 @@
#pragma once #pragma once
#include <common/types.h> #include <Common/SettingsChanges.h>
#include <Interpreters/Context_fwd.h> #include <Access/Authentication.h>
#include <Interpreters/ClientInfo.h> #include <Interpreters/ClientInfo.h>
#include <Interpreters/Context_fwd.h>
#include <chrono> #include <chrono>
#include <memory> #include <memory>
@ -13,77 +14,77 @@ namespace Poco::Net { class SocketAddress; }
namespace DB namespace DB
{ {
class Credentials; class Credentials;
class ContextAccess;
struct Settings;
class Authentication; class Authentication;
struct NamedSessionData; struct NamedSessionData;
class NamedSessionsStorage; class NamedSessionsStorage;
struct User;
using UserPtr = std::shared_ptr<const User>;
/** Represents user-session from the server perspective, /** Represents user-session from the server perspective,
* basically it is just a smaller subset of Context API, simplifies Context management. * basically it is just a smaller subset of Context API, simplifies Context management.
* *
* Holds session context, facilitates acquisition of NamedSession and proper creation of query contexts. * Holds session context, facilitates acquisition of NamedSession and proper creation of query contexts.
* Adds log in, log out and login failure events to the SessionLog.
*/ */
class Session class Session
{ {
static std::optional<NamedSessionsStorage> named_sessions;
public: public:
/// Allow to use named sessions. The thread will be run to cleanup sessions after timeout has expired. /// Allow to use named sessions. The thread will be run to cleanup sessions after timeout has expired.
/// The method must be called at the server startup. /// The method must be called at the server startup.
static void enableNamedSessions(); static void startupNamedSessions();
// static Session makeSessionFromCopyOfContext(const ContextPtr & _context_to_copy); Session(const ContextPtr & global_context_, ClientInfo::Interface interface_);
Session(const ContextPtr & context_to_copy, ClientInfo::Interface interface, std::optional<String> default_format = std::nullopt); Session(Session &&);
virtual ~Session(); ~Session();
Session(const Session &) = delete; Session(const Session &) = delete;
Session& operator=(const Session &) = delete; Session& operator=(const Session &) = delete;
Session(Session &&); /// Provides information about the authentication type of a specified user.
// Session& operator=(Session &&); Authentication::Type getAuthenticationType(const String & user_name) const;
Authentication::Digest getPasswordDoubleSHA1(const String & user_name) const;
Authentication getUserAuthentication(const String & user_name) const; /// Sets the current user, checks the credentials and that the specified address is allowed to connect from.
void setUser(const Credentials & credentials, const Poco::Net::SocketAddress & address); /// The function throws an exception if there is no such user or password is wrong.
void setUser(const String & name, const String & password, const Poco::Net::SocketAddress & address); void authenticate(const String & user_name, const String & password, const Poco::Net::SocketAddress & address);
void authenticate(const Credentials & credentials_, const Poco::Net::SocketAddress & address_);
/// Handle login and logout events.
void onLogInSuccess();
void onLogInFailure(const String & user_name, const std::exception & /* failure_reason */);
void onLogOut();
/** Propmotes current session to a named session.
*
* that is: re-uses or creates NamedSession and then piggybacks on it's context,
* retaining ClientInfo of current session_context.
* Acquired named_session is then released in the destructor.
*/
void promoteToNamedSession(const String & session_id, std::chrono::steady_clock::duration timeout, bool session_check);
/// Early release a NamedSession.
void releaseNamedSession();
ContextMutablePtr makeQueryContext(const String & query_id) const;
ContextPtr sessionContext() const;
ContextMutablePtr mutableSessionContext();
/// Returns a reference to session ClientInfo.
ClientInfo & getClientInfo(); ClientInfo & getClientInfo();
const ClientInfo & getClientInfo() const; const ClientInfo & getClientInfo() const;
const Settings & getSettings() const; /// Makes a session context, can be used one or zero times.
/// The function also assigns an user to this context.
ContextMutablePtr makeSessionContext();
ContextMutablePtr makeSessionContext(const String & session_id_, std::chrono::steady_clock::duration timeout_, bool session_check_);
ContextMutablePtr sessionContext() { return session_context; }
ContextPtr sessionContext() const { return session_context; }
void setQuotaKey(const String & quota_key); /// Makes a query context, can be used multiple times, with or without makeSession() called earlier.
/// The query context will be created from a copy of a session context if it exists, or from a copy of
String getCurrentDatabase() const; /// a global context otherwise. In the latter case the function also assigns an user to this context.
void setCurrentDatabase(const String & name); ContextMutablePtr makeQueryContext() const { return makeQueryContext(getClientInfo()); }
ContextMutablePtr makeQueryContext(const ClientInfo & query_client_info) const;
ContextMutablePtr makeQueryContext(ClientInfo && query_client_info) const;
private: private:
ContextMutablePtr makeQueryContextImpl(const ClientInfo * client_info_to_copy, ClientInfo * client_info_to_move) const;
const ContextPtr global_context;
/// ClientInfo that will be copied to a session context when it's created.
std::optional<ClientInfo> prepared_client_info;
mutable UserPtr user;
std::optional<UUID> user_id;
ContextMutablePtr session_context; ContextMutablePtr session_context;
// So that Session can be used after forced release of named_session. mutable bool query_context_created = false;
const ContextMutablePtr initial_session_context;
std::shared_ptr<const ContextAccess> access; String session_id;
std::shared_ptr<NamedSessionData> named_session; std::shared_ptr<NamedSessionData> named_session;
bool named_session_created = false;
static std::optional<NamedSessionsStorage> named_sessions;
}; };
} }

View File

@ -11,9 +11,9 @@
#include <DataStreams/PushingToSinkBlockOutputStream.h> #include <DataStreams/PushingToSinkBlockOutputStream.h>
#include <DataTypes/DataTypeFactory.h> #include <DataTypes/DataTypeFactory.h>
#include <Interpreters/Context.h> #include <Interpreters/Context.h>
#include <Interpreters/Session.h>
#include <Interpreters/InternalTextLogsQueue.h> #include <Interpreters/InternalTextLogsQueue.h>
#include <Interpreters/executeQuery.h> #include <Interpreters/executeQuery.h>
#include <Interpreters/Session.h>
#include <IO/ConcatReadBuffer.h> #include <IO/ConcatReadBuffer.h>
#include <IO/ReadBufferFromString.h> #include <IO/ReadBufferFromString.h>
#include <IO/ReadHelpers.h> #include <IO/ReadHelpers.h>
@ -55,7 +55,6 @@ namespace ErrorCodes
extern const int NETWORK_ERROR; extern const int NETWORK_ERROR;
extern const int NO_DATA_TO_INSERT; extern const int NO_DATA_TO_INSERT;
extern const int SUPPORT_IS_DISABLED; extern const int SUPPORT_IS_DISABLED;
extern const int UNKNOWN_DATABASE;
} }
namespace namespace
@ -561,7 +560,7 @@ namespace
IServer & iserver; IServer & iserver;
Poco::Logger * log = nullptr; Poco::Logger * log = nullptr;
std::shared_ptr<Session> session; std::optional<Session> session;
ContextMutablePtr query_context; ContextMutablePtr query_context;
std::optional<CurrentThread::QueryScope> query_scope; std::optional<CurrentThread::QueryScope> query_scope;
String query_text; String query_text;
@ -690,32 +689,20 @@ namespace
password = ""; password = "";
} }
/// Create context.
session = std::make_shared<Session>(iserver.context(), ClientInfo::Interface::GRPC);
/// Authentication. /// Authentication.
session->setUser(user, password, user_address); session.emplace(iserver.context(), ClientInfo::Interface::GRPC);
if (!quota_key.empty()) session->authenticate(user, password, user_address);
session->setQuotaKey(quota_key); session->getClientInfo().quota_key = quota_key;
/// The user could specify session identifier and session timeout. /// The user could specify session identifier and session timeout.
/// It allows to modify settings, create temporary tables and reuse them in subsequent requests. /// It allows to modify settings, create temporary tables and reuse them in subsequent requests.
if (!query_info.session_id().empty()) if (!query_info.session_id().empty())
{ {
session->promoteToNamedSession( session->makeSessionContext(
query_info.session_id(), query_info.session_id(), getSessionTimeout(query_info, iserver.config()), query_info.session_check());
getSessionTimeout(query_info, iserver.config()),
query_info.session_check());
} }
query_context = session->makeQueryContext(query_info.query_id()); query_context = session->makeQueryContext();
query_scope.emplace(query_context);
/// Set client info.
ClientInfo & client_info = query_context->getClientInfo();
client_info.query_kind = ClientInfo::QueryKind::INITIAL_QUERY;
client_info.initial_user = client_info.current_user;
client_info.initial_query_id = client_info.current_query_id;
client_info.initial_address = client_info.current_address;
/// Prepare settings. /// Prepare settings.
SettingsChanges settings_changes; SettingsChanges settings_changes;
@ -725,11 +712,14 @@ namespace
} }
query_context->checkSettingsConstraints(settings_changes); query_context->checkSettingsConstraints(settings_changes);
query_context->applySettingsChanges(settings_changes); query_context->applySettingsChanges(settings_changes);
const Settings & settings = query_context->getSettingsRef();
query_context->setCurrentQueryId(query_info.query_id());
query_scope.emplace(query_context);
/// Prepare for sending exceptions and logs. /// Prepare for sending exceptions and logs.
send_exception_with_stacktrace = query_context->getSettingsRef().calculate_text_stack_trace; const Settings & settings = query_context->getSettingsRef();
const auto client_logs_level = query_context->getSettingsRef().send_logs_level; send_exception_with_stacktrace = settings.calculate_text_stack_trace;
const auto client_logs_level = settings.send_logs_level;
if (client_logs_level != LogsLevel::none) if (client_logs_level != LogsLevel::none)
{ {
logs_queue = std::make_shared<InternalTextLogsQueue>(); logs_queue = std::make_shared<InternalTextLogsQueue>();
@ -740,14 +730,10 @@ namespace
/// Set the current database if specified. /// Set the current database if specified.
if (!query_info.database().empty()) if (!query_info.database().empty())
{
if (!DatabaseCatalog::instance().isDatabaseExist(query_info.database()))
throw Exception("Database " + query_info.database() + " doesn't exist", ErrorCodes::UNKNOWN_DATABASE);
query_context->setCurrentDatabase(query_info.database()); query_context->setCurrentDatabase(query_info.database());
}
/// The interactive delay will be used to show progress. /// The interactive delay will be used to show progress.
interactive_delay = query_context->getSettingsRef().interactive_delay; interactive_delay = settings.interactive_delay;
query_context->setProgressCallback([this](const Progress & value) { return progress.incrementPiecewiseAtomically(value); }); query_context->setProgressCallback([this](const Progress & value) { return progress.incrementPiecewiseAtomically(value); });
/// Parse the query. /// Parse the query.

View File

@ -19,9 +19,9 @@
#include <IO/WriteHelpers.h> #include <IO/WriteHelpers.h>
#include <IO/copyData.h> #include <IO/copyData.h>
#include <Interpreters/Context.h> #include <Interpreters/Context.h>
#include <Interpreters/Session.h>
#include <Interpreters/QueryParameterVisitor.h> #include <Interpreters/QueryParameterVisitor.h>
#include <Interpreters/executeQuery.h> #include <Interpreters/executeQuery.h>
#include <Interpreters/Session.h>
#include <Server/HTTPHandlerFactory.h> #include <Server/HTTPHandlerFactory.h>
#include <Server/HTTPHandlerRequestFilter.h> #include <Server/HTTPHandlerRequestFilter.h>
#include <Server/IServer.h> #include <Server/IServer.h>
@ -262,6 +262,7 @@ void HTTPHandler::pushDelayedResults(Output & used_output)
HTTPHandler::HTTPHandler(IServer & server_, const std::string & name) HTTPHandler::HTTPHandler(IServer & server_, const std::string & name)
: server(server_) : server(server_)
, log(&Poco::Logger::get(name)) , log(&Poco::Logger::get(name))
, default_settings(server.context()->getSettingsRef())
{ {
server_display_name = server.config().getString("display_name", getFQDNOrHostName()); server_display_name = server.config().getString("display_name", getFQDNOrHostName());
} }
@ -269,10 +270,7 @@ HTTPHandler::HTTPHandler(IServer & server_, const std::string & name)
/// We need d-tor to be present in this translation unit to make it play well with some /// We need d-tor to be present in this translation unit to make it play well with some
/// forward decls in the header. Other than that, the default d-tor would be OK. /// forward decls in the header. Other than that, the default d-tor would be OK.
HTTPHandler::~HTTPHandler() HTTPHandler::~HTTPHandler() = default;
{
(void)this;
}
bool HTTPHandler::authenticateUser( bool HTTPHandler::authenticateUser(
@ -352,7 +350,7 @@ bool HTTPHandler::authenticateUser(
else else
{ {
if (!request_credentials) if (!request_credentials)
request_credentials = request_session->sessionContext()->makeGSSAcceptorContext(); request_credentials = server.context()->makeGSSAcceptorContext();
auto * gss_acceptor_context = dynamic_cast<GSSAcceptorContext *>(request_credentials.get()); auto * gss_acceptor_context = dynamic_cast<GSSAcceptorContext *>(request_credentials.get());
if (!gss_acceptor_context) if (!gss_acceptor_context)
@ -378,9 +376,7 @@ bool HTTPHandler::authenticateUser(
} }
/// Set client info. It will be used for quota accounting parameters in 'setUser' method. /// Set client info. It will be used for quota accounting parameters in 'setUser' method.
ClientInfo & client_info = session->getClientInfo();
ClientInfo & client_info = request_session->getClientInfo();
client_info.query_kind = ClientInfo::QueryKind::INITIAL_QUERY;
ClientInfo::HTTPMethod http_method = ClientInfo::HTTPMethod::UNKNOWN; ClientInfo::HTTPMethod http_method = ClientInfo::HTTPMethod::UNKNOWN;
if (request.getMethod() == HTTPServerRequest::HTTP_GET) if (request.getMethod() == HTTPServerRequest::HTTP_GET)
@ -392,10 +388,11 @@ bool HTTPHandler::authenticateUser(
client_info.http_user_agent = request.get("User-Agent", ""); client_info.http_user_agent = request.get("User-Agent", "");
client_info.http_referer = request.get("Referer", ""); client_info.http_referer = request.get("Referer", "");
client_info.forwarded_for = request.get("X-Forwarded-For", ""); client_info.forwarded_for = request.get("X-Forwarded-For", "");
client_info.quota_key = quota_key;
try try
{ {
request_session->setUser(*request_credentials, request.clientAddress()); session->authenticate(*request_credentials, request.clientAddress());
} }
catch (const Authentication::Require<BasicCredentials> & required_credentials) catch (const Authentication::Require<BasicCredentials> & required_credentials)
{ {
@ -412,7 +409,7 @@ bool HTTPHandler::authenticateUser(
} }
catch (const Authentication::Require<GSSAcceptorContext> & required_credentials) catch (const Authentication::Require<GSSAcceptorContext> & required_credentials)
{ {
request_credentials = request_session->sessionContext()->makeGSSAcceptorContext(); request_credentials = server.context()->makeGSSAcceptorContext();
if (required_credentials.getRealm().empty()) if (required_credentials.getRealm().empty())
response.set("WWW-Authenticate", "Negotiate"); response.set("WWW-Authenticate", "Negotiate");
@ -425,14 +422,6 @@ bool HTTPHandler::authenticateUser(
} }
request_credentials.reset(); request_credentials.reset();
if (!quota_key.empty())
request_session->setQuotaKey(quota_key);
/// Query sent through HTTP interface is initial.
client_info.initial_user = client_info.current_user;
client_info.initial_address = client_info.current_address;
return true; return true;
} }
@ -463,20 +452,16 @@ void HTTPHandler::processQuery(
session_id = params.get("session_id"); session_id = params.get("session_id");
session_timeout = parseSessionTimeout(config, params); session_timeout = parseSessionTimeout(config, params);
std::string session_check = params.get("session_check", ""); std::string session_check = params.get("session_check", "");
request_session->promoteToNamedSession(session_id, session_timeout, session_check == "1"); session->makeSessionContext(session_id, session_timeout, session_check == "1");
} }
SCOPE_EXIT({
request_session->releaseNamedSession();
});
// Parse the OpenTelemetry traceparent header. // Parse the OpenTelemetry traceparent header.
// Disable in Arcadia -- it interferes with the // Disable in Arcadia -- it interferes with the
// test_clickhouse.TestTracing.test_tracing_via_http_proxy[traceparent] test. // test_clickhouse.TestTracing.test_tracing_via_http_proxy[traceparent] test.
ClientInfo client_info = session->getClientInfo();
#if !defined(ARCADIA_BUILD) #if !defined(ARCADIA_BUILD)
if (request.has("traceparent")) if (request.has("traceparent"))
{ {
ClientInfo & client_info = request_session->getClientInfo();
std::string opentelemetry_traceparent = request.get("traceparent"); std::string opentelemetry_traceparent = request.get("traceparent");
std::string error; std::string error;
if (!client_info.client_trace_context.parseTraceparentHeader( if (!client_info.client_trace_context.parseTraceparentHeader(
@ -486,16 +471,11 @@ void HTTPHandler::processQuery(
"Failed to parse OpenTelemetry traceparent header '{}': {}", "Failed to parse OpenTelemetry traceparent header '{}': {}",
opentelemetry_traceparent, error); opentelemetry_traceparent, error);
} }
client_info.client_trace_context.tracestate = request.get("tracestate", ""); client_info.client_trace_context.tracestate = request.get("tracestate", "");
} }
#endif #endif
// Set the query id supplied by the user, if any, and also update the OpenTelemetry fields. auto context = session->makeQueryContext(std::move(client_info));
auto context = request_session->makeQueryContext(params.get("query_id", request.get("X-ClickHouse-Query-Id", "")));
ClientInfo & client_info = context->getClientInfo();
client_info.initial_query_id = client_info.current_query_id;
/// The client can pass a HTTP header indicating supported compression method (gzip or deflate). /// The client can pass a HTTP header indicating supported compression method (gzip or deflate).
String http_response_compression_methods = request.get("Accept-Encoding", ""); String http_response_compression_methods = request.get("Accept-Encoding", "");
@ -560,7 +540,7 @@ void HTTPHandler::processQuery(
if (buffer_until_eof) if (buffer_until_eof)
{ {
const std::string tmp_path(context->getTemporaryVolume()->getDisk()->getPath()); const std::string tmp_path(server.context()->getTemporaryVolume()->getDisk()->getPath());
const std::string tmp_path_template(tmp_path + "http_buffers/"); const std::string tmp_path_template(tmp_path + "http_buffers/");
auto create_tmp_disk_buffer = [tmp_path_template] (const WriteBufferPtr &) auto create_tmp_disk_buffer = [tmp_path_template] (const WriteBufferPtr &)
@ -706,6 +686,9 @@ void HTTPHandler::processQuery(
context->checkSettingsConstraints(settings_changes); context->checkSettingsConstraints(settings_changes);
context->applySettingsChanges(settings_changes); context->applySettingsChanges(settings_changes);
// Set the query id supplied by the user, if any, and also update the OpenTelemetry fields.
context->setCurrentQueryId(params.get("query_id", request.get("X-ClickHouse-Query-Id", "")));
const auto & query = getQuery(request, params, context); const auto & query = getQuery(request, params, context);
std::unique_ptr<ReadBuffer> in_param = std::make_unique<ReadBufferFromString>(query); std::unique_ptr<ReadBuffer> in_param = std::make_unique<ReadBufferFromString>(query);
in = has_external_data ? std::move(in_param) : std::make_unique<ConcatReadBuffer>(*in_param, *in_post_maybe_compressed); in = has_external_data ? std::move(in_param) : std::make_unique<ConcatReadBuffer>(*in_param, *in_post_maybe_compressed);
@ -856,23 +839,10 @@ void HTTPHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse
setThreadName("HTTPHandler"); setThreadName("HTTPHandler");
ThreadStatus thread_status; ThreadStatus thread_status;
SCOPE_EXIT({ session = std::make_unique<Session>(server.context(), ClientInfo::Interface::HTTP);
// If there is no request_credentials instance waiting for the next round, then the request is processed, SCOPE_EXIT({ session.reset(); });
// so no need to preserve request_session either.
// Needs to be performed with respect to the other destructors in the scope though.
if (!request_credentials)
request_session.reset();
});
if (!request_session)
{
// Context should be initialized before anything, for correct memory accounting.
request_session = std::make_shared<Session>(server.context(), ClientInfo::Interface::HTTP);
request_credentials.reset();
}
/// Cannot be set here, since query_id is unknown.
std::optional<CurrentThread::QueryScope> query_scope; std::optional<CurrentThread::QueryScope> query_scope;
Output used_output; Output used_output;
/// In case of exception, send stack trace to client. /// In case of exception, send stack trace to client.
@ -886,7 +856,7 @@ void HTTPHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse
if (request.getVersion() == HTTPServerRequest::HTTP_1_1) if (request.getVersion() == HTTPServerRequest::HTTP_1_1)
response.setChunkedTransferEncoding(true); response.setChunkedTransferEncoding(true);
HTMLForm params(request_session->getSettings(), request); HTMLForm params(default_settings, request);
with_stacktrace = params.getParsed<bool>("stacktrace", false); with_stacktrace = params.getParsed<bool>("stacktrace", false);
/// FIXME: maybe this check is already unnecessary. /// FIXME: maybe this check is already unnecessary.

View File

@ -21,6 +21,7 @@ namespace DB
class Session; class Session;
class Credentials; class Credentials;
class IServer; class IServer;
struct Settings;
class WriteBufferFromHTTPServerResponse; class WriteBufferFromHTTPServerResponse;
using CompiledRegexPtr = std::shared_ptr<const re2::RE2>; using CompiledRegexPtr = std::shared_ptr<const re2::RE2>;
@ -72,15 +73,22 @@ private:
CurrentMetrics::Increment metric_increment{CurrentMetrics::HTTPConnection}; CurrentMetrics::Increment metric_increment{CurrentMetrics::HTTPConnection};
// The request_session and the request_credentials instances may outlive a single request/response loop. /// Reference to the immutable settings in the global context.
/// Those settings are used only to extract a http request's parameters.
/// See settings http_max_fields, http_max_field_name_size, http_max_field_value_size in HTMLForm.
const Settings & default_settings;
// session is reset at the end of each request/response.
std::unique_ptr<Session> session;
// The request_credential instance may outlive a single request/response loop.
// This happens only when the authentication mechanism requires more than a single request/response exchange (e.g., SPNEGO). // This happens only when the authentication mechanism requires more than a single request/response exchange (e.g., SPNEGO).
std::shared_ptr<Session> request_session;
std::unique_ptr<Credentials> request_credentials; std::unique_ptr<Credentials> request_credentials;
// Returns true when the user successfully authenticated, // Returns true when the user successfully authenticated,
// the request_session instance will be configured accordingly, and the request_credentials instance will be dropped. // the session instance will be configured accordingly, and the request_credentials instance will be dropped.
// Returns false when the user is not authenticated yet, and the 'Negotiate' response is sent, // Returns false when the user is not authenticated yet, and the 'Negotiate' response is sent,
// the request_session and request_credentials instances are preserved. // the session and request_credentials instances are preserved.
// Throws an exception if authentication failed. // Throws an exception if authentication failed.
bool authenticateUser( bool authenticateUser(
HTTPServerRequest & request, HTTPServerRequest & request,

View File

@ -3,11 +3,11 @@
#include <limits> #include <limits>
#include <Common/NetException.h> #include <Common/NetException.h>
#include <Common/OpenSSLHelpers.h> #include <Common/OpenSSLHelpers.h>
#include <Core/MySQL/Authentication.h>
#include <Core/MySQL/PacketsGeneric.h> #include <Core/MySQL/PacketsGeneric.h>
#include <Core/MySQL/PacketsConnection.h> #include <Core/MySQL/PacketsConnection.h>
#include <Core/MySQL/PacketsProtocolText.h> #include <Core/MySQL/PacketsProtocolText.h>
#include <Core/NamesAndTypes.h> #include <Core/NamesAndTypes.h>
#include <DataStreams/copyData.h>
#include <Interpreters/Session.h> #include <Interpreters/Session.h>
#include <Interpreters/executeQuery.h> #include <Interpreters/executeQuery.h>
#include <IO/copyData.h> #include <IO/copyData.h>
@ -19,9 +19,8 @@
#include <IO/ReadHelpers.h> #include <IO/ReadHelpers.h>
#include <Storages/IStorage.h> #include <Storages/IStorage.h>
#include <regex> #include <regex>
#include <Access/User.h>
#include <Access/AccessControlManager.h>
#include <Common/setThreadName.h> #include <Common/setThreadName.h>
#include <Core/MySQL/Authentication.h>
#if !defined(ARCADIA_BUILD) #if !defined(ARCADIA_BUILD)
# include <Common/config_version.h> # include <Common/config_version.h>
@ -88,12 +87,10 @@ void MySQLHandler::run()
setThreadName("MySQLHandler"); setThreadName("MySQLHandler");
ThreadStatus thread_status; ThreadStatus thread_status;
session = std::make_shared<Session>(server.context(), ClientInfo::Interface::MYSQL, "MySQLWire"); session = std::make_unique<Session>(server.context(), ClientInfo::Interface::MYSQL);
auto & session_client_info = session->getClientInfo(); SCOPE_EXIT({ session.reset(); });
session_client_info.current_address = socket().peerAddress(); session->getClientInfo().connection_id = connection_id;
session_client_info.connection_id = connection_id;
session_client_info.query_kind = ClientInfo::QueryKind::INITIAL_QUERY;
in = std::make_shared<ReadBufferFromPocoSocket>(socket()); in = std::make_shared<ReadBufferFromPocoSocket>(socket());
out = std::make_shared<WriteBufferFromPocoSocket>(socket()); out = std::make_shared<WriteBufferFromPocoSocket>(socket());
@ -127,12 +124,12 @@ void MySQLHandler::run()
authenticate(handshake_response.username, handshake_response.auth_plugin_name, handshake_response.auth_response); authenticate(handshake_response.username, handshake_response.auth_plugin_name, handshake_response.auth_response);
session_client_info.initial_user = handshake_response.username;
try try
{ {
session->makeSessionContext();
session->sessionContext()->setDefaultFormat("MySQLWire");
if (!handshake_response.database.empty()) if (!handshake_response.database.empty())
session->setCurrentDatabase(handshake_response.database); session->sessionContext()->setCurrentDatabase(handshake_response.database);
} }
catch (const Exception & exc) catch (const Exception & exc)
{ {
@ -246,26 +243,23 @@ void MySQLHandler::finishHandshake(MySQLProtocol::ConnectionPhase::HandshakeResp
void MySQLHandler::authenticate(const String & user_name, const String & auth_plugin_name, const String & initial_auth_response) void MySQLHandler::authenticate(const String & user_name, const String & auth_plugin_name, const String & initial_auth_response)
{ {
// For compatibility with JavaScript MySQL client, Native41 authentication plugin is used when possible (if password is specified using double SHA1). Otherwise SHA256 plugin is used.
DB::Authentication::Type user_auth_type;
try try
{ {
user_auth_type = session->getUserAuthentication(user_name).getType(); // For compatibility with JavaScript MySQL client, Native41 authentication plugin is used when possible (if password is specified using double SHA1). Otherwise SHA256 plugin is used.
if (session->getAuthenticationType(user_name) == DB::Authentication::SHA256_PASSWORD)
{
authPluginSSL();
}
std::optional<String> auth_response = auth_plugin_name == auth_plugin->getName() ? std::make_optional<String>(initial_auth_response) : std::nullopt;
auth_plugin->authenticate(user_name, *session, auth_response, packet_endpoint, secure_connection, socket().peerAddress());
} }
catch (const std::exception & e) catch (const Exception & exc)
{ {
session->onLogInFailure(user_name, e); LOG_ERROR(log, "Authentication for user {} failed.", user_name);
packet_endpoint->sendPacket(ERRPacket(exc.code(), "00000", exc.message()), true);
throw; throw;
} }
if (user_auth_type == DB::Authentication::SHA256_PASSWORD)
{
authPluginSSL();
}
std::optional<String> auth_response = auth_plugin_name == auth_plugin->getName() ? std::make_optional<String>(initial_auth_response) : std::nullopt;
auth_plugin->authenticate(user_name, auth_response, *session, packet_endpoint, secure_connection, socket().peerAddress());
LOG_DEBUG(log, "Authentication for user {} succeeded.", user_name); LOG_DEBUG(log, "Authentication for user {} succeeded.", user_name);
} }
@ -274,7 +268,7 @@ void MySQLHandler::comInitDB(ReadBuffer & payload)
String database; String database;
readStringUntilEOF(database, payload); readStringUntilEOF(database, payload);
LOG_DEBUG(log, "Setting current database to {}", database); LOG_DEBUG(log, "Setting current database to {}", database);
session->setCurrentDatabase(database); session->sessionContext()->setCurrentDatabase(database);
packet_endpoint->sendPacket(OKPacket(0, client_capabilities, 0, 0, 1), true); packet_endpoint->sendPacket(OKPacket(0, client_capabilities, 0, 0, 1), true);
} }
@ -331,7 +325,9 @@ void MySQLHandler::comQuery(ReadBuffer & payload)
ReadBufferFromString replacement(replacement_query); ReadBufferFromString replacement(replacement_query);
auto query_context = session->makeQueryContext(Poco::format("mysql:%lu", connection_id)); auto query_context = session->makeQueryContext();
query_context->setCurrentQueryId(Poco::format("mysql:%lu", connection_id));
CurrentThread::QueryScope query_scope{query_context};
std::atomic<size_t> affected_rows {0}; std::atomic<size_t> affected_rows {0};
auto prev = query_context->getProgressCallback(); auto prev = query_context->getProgressCallback();
@ -343,8 +339,6 @@ void MySQLHandler::comQuery(ReadBuffer & payload)
affected_rows += progress.written_rows; affected_rows += progress.written_rows;
}); });
CurrentThread::QueryScope query_scope{query_context};
FormatSettings format_settings; FormatSettings format_settings;
format_settings.mysql_wire.client_capabilities = client_capabilities; format_settings.mysql_wire.client_capabilities = client_capabilities;
format_settings.mysql_wire.max_packet_size = max_packet_size; format_settings.mysql_wire.max_packet_size = max_packet_size;

View File

@ -63,7 +63,7 @@ protected:
uint8_t sequence_id = 0; uint8_t sequence_id = 0;
MySQLProtocol::PacketEndpointPtr packet_endpoint; MySQLProtocol::PacketEndpointPtr packet_endpoint;
std::shared_ptr<Session> session; std::unique_ptr<Session> session;
using ReplacementFn = std::function<String(const String & query)>; using ReplacementFn = std::function<String(const String & query)>;
using Replacements = std::unordered_map<std::string, ReplacementFn>; using Replacements = std::unordered_map<std::string, ReplacementFn>;

View File

@ -2,8 +2,8 @@
#include <IO/ReadHelpers.h> #include <IO/ReadHelpers.h>
#include <IO/ReadBufferFromString.h> #include <IO/ReadBufferFromString.h>
#include <IO/WriteBufferFromPocoSocket.h> #include <IO/WriteBufferFromPocoSocket.h>
#include <Interpreters/Context.h>
#include <Interpreters/executeQuery.h> #include <Interpreters/executeQuery.h>
#include <Interpreters/Session.h>
#include "PostgreSQLHandler.h" #include "PostgreSQLHandler.h"
#include <Parsers/parseQuery.h> #include <Parsers/parseQuery.h>
#include <Common/setThreadName.h> #include <Common/setThreadName.h>
@ -53,14 +53,12 @@ void PostgreSQLHandler::run()
setThreadName("PostgresHandler"); setThreadName("PostgresHandler");
ThreadStatus thread_status; ThreadStatus thread_status;
Session session(server.context(), ClientInfo::Interface::POSTGRESQL, "PostgreSQLWire"); session = std::make_unique<Session>(server.context(), ClientInfo::Interface::POSTGRESQL);
auto & session_client_info = session.getClientInfo(); SCOPE_EXIT({ session.reset(); });
session_client_info.query_kind = ClientInfo::QueryKind::INITIAL_QUERY;
try try
{ {
if (!startup(session)) if (!startup())
return; return;
while (true) while (true)
@ -71,7 +69,7 @@ void PostgreSQLHandler::run()
switch (message_type) switch (message_type)
{ {
case PostgreSQLProtocol::Messaging::FrontMessageType::QUERY: case PostgreSQLProtocol::Messaging::FrontMessageType::QUERY:
processQuery(session); processQuery();
break; break;
case PostgreSQLProtocol::Messaging::FrontMessageType::TERMINATE: case PostgreSQLProtocol::Messaging::FrontMessageType::TERMINATE:
LOG_DEBUG(log, "Client closed the connection"); LOG_DEBUG(log, "Client closed the connection");
@ -110,7 +108,7 @@ void PostgreSQLHandler::run()
} }
bool PostgreSQLHandler::startup(Session & session) bool PostgreSQLHandler::startup()
{ {
Int32 payload_size; Int32 payload_size;
Int32 info; Int32 info;
@ -119,17 +117,20 @@ bool PostgreSQLHandler::startup(Session & session)
if (static_cast<PostgreSQLProtocol::Messaging::FrontMessageType>(info) == PostgreSQLProtocol::Messaging::FrontMessageType::CANCEL_REQUEST) if (static_cast<PostgreSQLProtocol::Messaging::FrontMessageType>(info) == PostgreSQLProtocol::Messaging::FrontMessageType::CANCEL_REQUEST)
{ {
LOG_DEBUG(log, "Client issued request canceling"); LOG_DEBUG(log, "Client issued request canceling");
cancelRequest(session); cancelRequest();
return false; return false;
} }
std::unique_ptr<PostgreSQLProtocol::Messaging::StartupMessage> start_up_msg = receiveStartupMessage(payload_size); std::unique_ptr<PostgreSQLProtocol::Messaging::StartupMessage> start_up_msg = receiveStartupMessage(payload_size);
authentication_manager.authenticate(start_up_msg->user, session, *message_transport, socket().peerAddress()); const auto & user_name = start_up_msg->user;
authentication_manager.authenticate(user_name, *session, *message_transport, socket().peerAddress());
try try
{ {
session->makeSessionContext();
session->sessionContext()->setDefaultFormat("PostgreSQLWire");
if (!start_up_msg->database.empty()) if (!start_up_msg->database.empty())
session.setCurrentDatabase(start_up_msg->database); session->sessionContext()->setCurrentDatabase(start_up_msg->database);
} }
catch (const Exception & exc) catch (const Exception & exc)
{ {
@ -207,18 +208,16 @@ void PostgreSQLHandler::sendParameterStatusData(PostgreSQLProtocol::Messaging::S
message_transport->flush(); message_transport->flush();
} }
void PostgreSQLHandler::cancelRequest(Session & session) void PostgreSQLHandler::cancelRequest()
{ {
// TODO (nemkov): maybe run cancellation query with session context?
auto query_context = session.makeQueryContext(std::string{});
query_context->setDefaultFormat("Null");
std::unique_ptr<PostgreSQLProtocol::Messaging::CancelRequest> msg = std::unique_ptr<PostgreSQLProtocol::Messaging::CancelRequest> msg =
message_transport->receiveWithPayloadSize<PostgreSQLProtocol::Messaging::CancelRequest>(8); message_transport->receiveWithPayloadSize<PostgreSQLProtocol::Messaging::CancelRequest>(8);
String query = Poco::format("KILL QUERY WHERE query_id = 'postgres:%d:%d'", msg->process_id, msg->secret_key); String query = Poco::format("KILL QUERY WHERE query_id = 'postgres:%d:%d'", msg->process_id, msg->secret_key);
ReadBufferFromString replacement(query); ReadBufferFromString replacement(query);
auto query_context = session->makeQueryContext();
query_context->setCurrentQueryId("");
executeQuery(replacement, *out, true, query_context, {}); executeQuery(replacement, *out, true, query_context, {});
} }
@ -242,7 +241,7 @@ inline std::unique_ptr<PostgreSQLProtocol::Messaging::StartupMessage> PostgreSQL
return message; return message;
} }
void PostgreSQLHandler::processQuery(Session & session) void PostgreSQLHandler::processQuery()
{ {
try try
{ {
@ -265,7 +264,7 @@ void PostgreSQLHandler::processQuery(Session & session)
return; return;
} }
const auto & settings = session.getSettings(); const auto & settings = session->sessionContext()->getSettingsRef();
std::vector<String> queries; std::vector<String> queries;
auto parse_res = splitMultipartQuery(query->query, queries, settings.max_query_size, settings.max_parser_depth); auto parse_res = splitMultipartQuery(query->query, queries, settings.max_query_size, settings.max_parser_depth);
if (!parse_res.second) if (!parse_res.second)
@ -278,7 +277,8 @@ void PostgreSQLHandler::processQuery(Session & session)
for (const auto & spl_query : queries) for (const auto & spl_query : queries)
{ {
secret_key = dis(gen); secret_key = dis(gen);
auto query_context = session.makeQueryContext(Poco::format("postgres:%d:%d", connection_id, secret_key)); auto query_context = session->makeQueryContext();
query_context->setCurrentQueryId(Poco::format("postgres:%d:%d", connection_id, secret_key));
CurrentThread::QueryScope query_scope{query_context}; CurrentThread::QueryScope query_scope{query_context};
ReadBufferFromString read_buf(spl_query); ReadBufferFromString read_buf(spl_query);

View File

@ -39,6 +39,7 @@ private:
Poco::Logger * log = &Poco::Logger::get("PostgreSQLHandler"); Poco::Logger * log = &Poco::Logger::get("PostgreSQLHandler");
IServer & server; IServer & server;
std::unique_ptr<Session> session;
bool ssl_enabled = false; bool ssl_enabled = false;
Int32 connection_id = 0; Int32 connection_id = 0;
Int32 secret_key = 0; Int32 secret_key = 0;
@ -57,7 +58,7 @@ private:
void changeIO(Poco::Net::StreamSocket & socket); void changeIO(Poco::Net::StreamSocket & socket);
bool startup(Session & session); bool startup();
void establishSecureConnection(Int32 & payload_size, Int32 & info); void establishSecureConnection(Int32 & payload_size, Int32 & info);
@ -65,11 +66,11 @@ private:
void sendParameterStatusData(PostgreSQLProtocol::Messaging::StartupMessage & start_up_message); void sendParameterStatusData(PostgreSQLProtocol::Messaging::StartupMessage & start_up_message);
void cancelRequest(Session & session); void cancelRequest();
std::unique_ptr<PostgreSQLProtocol::Messaging::StartupMessage> receiveStartupMessage(int payload_size); std::unique_ptr<PostgreSQLProtocol::Messaging::StartupMessage> receiveStartupMessage(int payload_size);
void processQuery(DB::Session & session); void processQuery();
static bool isEmptyQuery(const String & query); static bool isEmptyQuery(const String & query);
}; };

View File

@ -20,16 +20,16 @@
#include <DataStreams/NativeBlockInputStream.h> #include <DataStreams/NativeBlockInputStream.h>
#include <DataStreams/NativeBlockOutputStream.h> #include <DataStreams/NativeBlockOutputStream.h>
#include <DataStreams/PushingToSinkBlockOutputStream.h> #include <DataStreams/PushingToSinkBlockOutputStream.h>
#include <Interpreters/Context.h>
#include <Interpreters/executeQuery.h> #include <Interpreters/executeQuery.h>
#include <Interpreters/TablesStatus.h> #include <Interpreters/TablesStatus.h>
#include <Interpreters/InternalTextLogsQueue.h> #include <Interpreters/InternalTextLogsQueue.h>
#include <Interpreters/Session.h>
#include <Interpreters/OpenTelemetrySpanLog.h> #include <Interpreters/OpenTelemetrySpanLog.h>
#include <Interpreters/Session.h>
#include <Storages/StorageReplicatedMergeTree.h> #include <Storages/StorageReplicatedMergeTree.h>
#include <Storages/MergeTree/MergeTreeDataPartUUID.h> #include <Storages/MergeTree/MergeTreeDataPartUUID.h>
#include <Storages/StorageS3Cluster.h> #include <Storages/StorageS3Cluster.h>
#include <Core/ExternalTable.h> #include <Core/ExternalTable.h>
#include <Access/Credentials.h>
#include <Storages/ColumnDefault.h> #include <Storages/ColumnDefault.h>
#include <DataTypes/DataTypeLowCardinality.h> #include <DataTypes/DataTypeLowCardinality.h>
#include <Compression/CompressionFactory.h> #include <Compression/CompressionFactory.h>
@ -75,7 +75,6 @@ namespace ErrorCodes
extern const int LOGICAL_ERROR; extern const int LOGICAL_ERROR;
extern const int ATTEMPT_TO_READ_AFTER_EOF; extern const int ATTEMPT_TO_READ_AFTER_EOF;
extern const int CLIENT_HAS_CONNECTED_TO_WRONG_PORT; extern const int CLIENT_HAS_CONNECTED_TO_WRONG_PORT;
extern const int UNKNOWN_DATABASE;
extern const int UNKNOWN_EXCEPTION; extern const int UNKNOWN_EXCEPTION;
extern const int UNKNOWN_PACKET_FROM_CLIENT; extern const int UNKNOWN_PACKET_FROM_CLIENT;
extern const int POCO_EXCEPTION; extern const int POCO_EXCEPTION;
@ -90,7 +89,6 @@ TCPHandler::TCPHandler(IServer & server_, const Poco::Net::StreamSocket & socket
, server(server_) , server(server_)
, parse_proxy_protocol(parse_proxy_protocol_) , parse_proxy_protocol(parse_proxy_protocol_)
, log(&Poco::Logger::get("TCPHandler")) , log(&Poco::Logger::get("TCPHandler"))
, query_context(Context::createCopy(server.context()))
, server_display_name(std::move(server_display_name_)) , server_display_name(std::move(server_display_name_))
{ {
} }
@ -115,16 +113,10 @@ void TCPHandler::runImpl()
ThreadStatus thread_status; ThreadStatus thread_status;
session = std::make_unique<Session>(server.context(), ClientInfo::Interface::TCP); session = std::make_unique<Session>(server.context(), ClientInfo::Interface::TCP);
const auto session_context = session->sessionContext(); extractConnectionSettingsFromContext(server.context());
/// These timeouts can be changed after receiving query. socket().setReceiveTimeout(receive_timeout);
const auto & settings = session->getSettings(); socket().setSendTimeout(send_timeout);
auto global_receive_timeout = settings.receive_timeout;
auto global_send_timeout = settings.send_timeout;
socket().setReceiveTimeout(global_receive_timeout);
socket().setSendTimeout(global_send_timeout);
socket().setNoDelay(true); socket().setNoDelay(true);
in = std::make_shared<ReadBufferFromPocoSocket>(socket()); in = std::make_shared<ReadBufferFromPocoSocket>(socket());
@ -162,33 +154,27 @@ void TCPHandler::runImpl()
try try
{ {
/// We try to send error information to the client. /// We try to send error information to the client.
sendException(e, session->getSettings().calculate_text_stack_trace); sendException(e, send_exception_with_stack_trace);
} }
catch (...) {} catch (...) {}
throw; throw;
} }
/// When connecting, the default database can be specified.
if (!default_database.empty())
{
if (!DatabaseCatalog::instance().isDatabaseExist(default_database))
{
Exception e("Database " + backQuote(default_database) + " doesn't exist", ErrorCodes::UNKNOWN_DATABASE);
LOG_ERROR(log, getExceptionMessage(e, true));
sendException(e, settings.calculate_text_stack_trace);
return;
}
session->setCurrentDatabase(default_database);
}
UInt64 idle_connection_timeout = settings.idle_connection_timeout;
UInt64 poll_interval = settings.poll_interval;
sendHello(); sendHello();
session->mutableSessionContext()->setProgressCallback([this] (const Progress & value) { return this->updateProgress(value); }); if (!is_interserver_mode) /// In interserver mode queries are executed without a session context.
{
session->makeSessionContext();
/// If session created, then settings in session context has been updated.
/// So it's better to update the connection settings for flexibility.
extractConnectionSettingsFromContext(session->sessionContext());
/// When connecting, the default database could be specified.
if (!default_database.empty())
session->sessionContext()->setCurrentDatabase(default_database);
}
while (true) while (true)
{ {
@ -210,10 +196,6 @@ void TCPHandler::runImpl()
if (server.isCancelled() || in->eof()) if (server.isCancelled() || in->eof())
break; break;
/// Set context of request.
/// TODO (nemkov): create query later in receiveQuery
query_context = session->makeQueryContext(std::string{}); // proper query_id is set later in receiveQuery
Stopwatch watch; Stopwatch watch;
state.reset(); state.reset();
@ -226,8 +208,6 @@ void TCPHandler::runImpl()
std::optional<DB::Exception> exception; std::optional<DB::Exception> exception;
bool network_error = false; bool network_error = false;
bool send_exception_with_stack_trace = true;
try try
{ {
/// If a user passed query-local timeouts, reset socket to initial state at the end of the query /// If a user passed query-local timeouts, reset socket to initial state at the end of the query
@ -240,23 +220,22 @@ void TCPHandler::runImpl()
if (!receivePacket()) if (!receivePacket())
continue; continue;
/** If Query received, then settings in query_context has been updated
* So, update some other connection settings, for flexibility.
*/
{
const Settings & query_settings = query_context->getSettingsRef();
idle_connection_timeout = query_settings.idle_connection_timeout;
poll_interval = query_settings.poll_interval;
}
/** If part_uuids got received in previous packet, trying to read again. /** If part_uuids got received in previous packet, trying to read again.
*/ */
if (state.empty() && state.part_uuids && !receivePacket()) if (state.empty() && state.part_uuids_to_ignore && !receivePacket())
continue; continue;
query_scope.emplace(query_context); query_scope.emplace(query_context);
send_exception_with_stack_trace = query_context->getSettingsRef().calculate_text_stack_trace; /// If query received, then settings in query_context has been updated.
/// So it's better to update the connection settings for flexibility.
extractConnectionSettingsFromContext(query_context);
/// Sync timeouts on client and server during current query to avoid dangling queries on server
/// NOTE: We use send_timeout for the receive timeout and vice versa (change arguments ordering in TimeoutSetter),
/// because send_timeout is client-side setting which has opposite meaning on the server side.
/// NOTE: these settings are applied only for current connection (not for distributed tables' connections)
state.timeout_setter = std::make_unique<TimeoutSetter>(socket(), receive_timeout, send_timeout);
/// Should we send internal logs to client? /// Should we send internal logs to client?
const auto client_logs_level = query_context->getSettingsRef().send_logs_level; const auto client_logs_level = query_context->getSettingsRef().send_logs_level;
@ -269,20 +248,18 @@ void TCPHandler::runImpl()
CurrentThread::setFatalErrorCallback([this]{ sendLogs(); }); CurrentThread::setFatalErrorCallback([this]{ sendLogs(); });
} }
query_context->setExternalTablesInitializer([&settings, this] (ContextPtr context) query_context->setExternalTablesInitializer([this] (ContextPtr context)
{ {
if (context != query_context) if (context != query_context)
throw Exception("Unexpected context in external tables initializer", ErrorCodes::LOGICAL_ERROR); throw Exception("Unexpected context in external tables initializer", ErrorCodes::LOGICAL_ERROR);
/// Get blocks of temporary tables /// Get blocks of temporary tables
readData(settings); readData();
/// Reset the input stream, as we received an empty block while receiving external table data. /// Reset the input stream, as we received an empty block while receiving external table data.
/// So, the stream has been marked as cancelled and we can't read from it anymore. /// So, the stream has been marked as cancelled and we can't read from it anymore.
state.block_in.reset(); state.block_in.reset();
state.maybe_compressed_in.reset(); /// For more accurate accounting by MemoryTracker. state.maybe_compressed_in.reset(); /// For more accurate accounting by MemoryTracker.
state.temporary_tables_read = true;
}); });
/// Send structure of columns to client for function input() /// Send structure of columns to client for function input()
@ -306,15 +283,12 @@ void TCPHandler::runImpl()
sendData(state.input_header); sendData(state.input_header);
}); });
query_context->setInputBlocksReaderCallback([&settings, this] (ContextPtr context) -> Block query_context->setInputBlocksReaderCallback([this] (ContextPtr context) -> Block
{ {
if (context != query_context) if (context != query_context)
throw Exception("Unexpected context in InputBlocksReader", ErrorCodes::LOGICAL_ERROR); throw Exception("Unexpected context in InputBlocksReader", ErrorCodes::LOGICAL_ERROR);
size_t poll_interval_ms; if (!readDataNext())
int receive_timeout;
std::tie(poll_interval_ms, receive_timeout) = getReadTimeouts(settings);
if (!readDataNext(poll_interval_ms, receive_timeout))
{ {
state.block_in.reset(); state.block_in.reset();
state.maybe_compressed_in.reset(); state.maybe_compressed_in.reset();
@ -337,15 +311,13 @@ void TCPHandler::runImpl()
/// Processing Query /// Processing Query
state.io = executeQuery(state.query, query_context, false, state.stage, may_have_embedded_data); state.io = executeQuery(state.query, query_context, false, state.stage, may_have_embedded_data);
unknown_packet_in_send_data = query_context->getSettingsRef().unknown_packet_in_send_data;
after_check_cancelled.restart(); after_check_cancelled.restart();
after_send_progress.restart(); after_send_progress.restart();
if (state.io.out) if (state.io.out)
{ {
state.need_receive_data_for_insert = true; state.need_receive_data_for_insert = true;
processInsertQuery(settings); processInsertQuery();
} }
else if (state.need_receive_data_for_input) // It implies pipeline execution else if (state.need_receive_data_for_input) // It implies pipeline execution
{ {
@ -461,16 +433,17 @@ void TCPHandler::runImpl()
try try
{ {
if (exception && !state.temporary_tables_read) /// A query packet is always followed by one or more data packets.
query_context->initializeExternalTablesIfSet(); /// If some of those data packets are left, try to skip them.
if (exception && !state.empty() && !state.read_all_data)
skipData();
} }
catch (...) catch (...)
{ {
network_error = true; network_error = true;
LOG_WARNING(log, "Can't read external tables after query failure."); LOG_WARNING(log, "Can't skip data packets after query failure.");
} }
try try
{ {
/// QueryState should be cleared before QueryScope, since otherwise /// QueryState should be cleared before QueryScope, since otherwise
@ -501,75 +474,94 @@ void TCPHandler::runImpl()
} }
bool TCPHandler::readDataNext(size_t poll_interval, time_t receive_timeout) void TCPHandler::extractConnectionSettingsFromContext(const ContextPtr & context)
{
const auto & settings = context->getSettingsRef();
send_exception_with_stack_trace = settings.calculate_text_stack_trace;
send_timeout = settings.send_timeout;
receive_timeout = settings.receive_timeout;
poll_interval = settings.poll_interval;
idle_connection_timeout = settings.idle_connection_timeout;
interactive_delay = settings.interactive_delay;
sleep_in_send_tables_status = settings.sleep_in_send_tables_status_ms;
unknown_packet_in_send_data = settings.unknown_packet_in_send_data;
sleep_in_receive_cancel = settings.sleep_in_receive_cancel_ms;
}
bool TCPHandler::readDataNext()
{ {
Stopwatch watch(CLOCK_MONOTONIC_COARSE); Stopwatch watch(CLOCK_MONOTONIC_COARSE);
/// Poll interval should not be greater than receive_timeout
constexpr UInt64 min_timeout_ms = 5000; // 5 ms
UInt64 timeout_ms = std::max(min_timeout_ms, std::min(poll_interval * 1000000, static_cast<UInt64>(receive_timeout.totalMicroseconds())));
bool read_ok = false;
/// We are waiting for a packet from the client. Thus, every `POLL_INTERVAL` seconds check whether we need to shut down. /// We are waiting for a packet from the client. Thus, every `POLL_INTERVAL` seconds check whether we need to shut down.
while (true) while (true)
{ {
if (static_cast<ReadBufferFromPocoSocket &>(*in).poll(poll_interval)) if (static_cast<ReadBufferFromPocoSocket &>(*in).poll(timeout_ms))
{
/// If client disconnected.
if (in->eof())
{
LOG_INFO(log, "Client has dropped the connection, cancel the query.");
state.is_connection_closed = true;
break;
}
/// We accept and process data.
read_ok = receivePacket();
break; break;
}
/// Do we need to shut down? /// Do we need to shut down?
if (server.isCancelled()) if (server.isCancelled())
return false; break;
/** Have we waited for data for too long? /** Have we waited for data for too long?
* If we periodically poll, the receive_timeout of the socket itself does not work. * If we periodically poll, the receive_timeout of the socket itself does not work.
* Therefore, an additional check is added. * Therefore, an additional check is added.
*/ */
Float64 elapsed = watch.elapsedSeconds(); Float64 elapsed = watch.elapsedSeconds();
if (elapsed > static_cast<Float64>(receive_timeout)) if (elapsed > static_cast<Float64>(receive_timeout.totalSeconds()))
{ {
throw Exception(ErrorCodes::SOCKET_TIMEOUT, throw Exception(ErrorCodes::SOCKET_TIMEOUT,
"Timeout exceeded while receiving data from client. Waited for {} seconds, timeout is {} seconds.", "Timeout exceeded while receiving data from client. Waited for {} seconds, timeout is {} seconds.",
static_cast<size_t>(elapsed), receive_timeout); static_cast<size_t>(elapsed), receive_timeout.totalSeconds());
} }
} }
/// If client disconnected. if (read_ok)
if (in->eof()) sendLogs();
{ else
LOG_INFO(log, "Client has dropped the connection, cancel the query."); state.read_all_data = true;
state.is_connection_closed = true;
return false;
}
/// We accept and process data. And if they are over, then we leave. return read_ok;
if (!receivePacket())
return false;
sendLogs();
return true;
} }
std::tuple<size_t, int> TCPHandler::getReadTimeouts(const Settings & connection_settings) void TCPHandler::readData()
{ {
const auto receive_timeout = query_context->getSettingsRef().receive_timeout.value;
/// Poll interval should not be greater than receive_timeout
const size_t default_poll_interval = connection_settings.poll_interval * 1000000;
size_t current_poll_interval = static_cast<size_t>(receive_timeout.totalMicroseconds());
constexpr size_t min_poll_interval = 5000; // 5 ms
size_t poll_interval = std::max(min_poll_interval, std::min(default_poll_interval, current_poll_interval));
return std::make_tuple(poll_interval, receive_timeout.totalSeconds());
}
void TCPHandler::readData(const Settings & connection_settings)
{
auto [poll_interval, receive_timeout] = getReadTimeouts(connection_settings);
sendLogs(); sendLogs();
while (readDataNext(poll_interval, receive_timeout)) while (readDataNext())
; ;
} }
void TCPHandler::processInsertQuery(const Settings & connection_settings) void TCPHandler::skipData()
{
state.skipping_data = true;
SCOPE_EXIT({ state.skipping_data = false; });
while (readDataNext())
;
}
void TCPHandler::processInsertQuery()
{ {
/** Made above the rest of the lines, so that in case of `writePrefix` function throws an exception, /** Made above the rest of the lines, so that in case of `writePrefix` function throws an exception,
* client receive exception before sending data. * client receive exception before sending data.
@ -595,7 +587,7 @@ void TCPHandler::processInsertQuery(const Settings & connection_settings)
try try
{ {
readData(connection_settings); readData();
} }
catch (...) catch (...)
{ {
@ -634,7 +626,7 @@ void TCPHandler::processOrdinaryQuery()
break; break;
} }
if (after_send_progress.elapsed() / 1000 >= query_context->getSettingsRef().interactive_delay) if (after_send_progress.elapsed() / 1000 >= interactive_delay)
{ {
/// Some time passed. /// Some time passed.
after_send_progress.restart(); after_send_progress.restart();
@ -643,7 +635,7 @@ void TCPHandler::processOrdinaryQuery()
sendLogs(); sendLogs();
if (async_in.poll(query_context->getSettingsRef().interactive_delay / 1000)) if (async_in.poll(interactive_delay / 1000))
{ {
const auto block = async_in.read(); const auto block = async_in.read();
if (!block) if (!block)
@ -698,7 +690,7 @@ void TCPHandler::processOrdinaryQueryWithProcessors()
CurrentMetrics::Increment query_thread_metric_increment{CurrentMetrics::QueryThread}; CurrentMetrics::Increment query_thread_metric_increment{CurrentMetrics::QueryThread};
Block block; Block block;
while (executor.pull(block, query_context->getSettingsRef().interactive_delay / 1000)) while (executor.pull(block, interactive_delay / 1000))
{ {
std::lock_guard lock(task_callback_mutex); std::lock_guard lock(task_callback_mutex);
@ -709,7 +701,7 @@ void TCPHandler::processOrdinaryQueryWithProcessors()
break; break;
} }
if (after_send_progress.elapsed() / 1000 >= query_context->getSettingsRef().interactive_delay) if (after_send_progress.elapsed() / 1000 >= interactive_delay)
{ {
/// Some time passed and there is a progress. /// Some time passed and there is a progress.
after_send_progress.restart(); after_send_progress.restart();
@ -755,13 +747,14 @@ void TCPHandler::processTablesStatusRequest()
{ {
TablesStatusRequest request; TablesStatusRequest request;
request.read(*in, client_tcp_protocol_version); request.read(*in, client_tcp_protocol_version);
const auto session_context = session->sessionContext();
ContextPtr context_to_resolve_table_names = session->sessionContext() ? session->sessionContext() : server.context();
TablesStatusResponse response; TablesStatusResponse response;
for (const QualifiedTableName & table_name: request.tables) for (const QualifiedTableName & table_name: request.tables)
{ {
auto resolved_id = session_context->tryResolveStorageID({table_name.database, table_name.table}); auto resolved_id = context_to_resolve_table_names->tryResolveStorageID({table_name.database, table_name.table});
StoragePtr table = DatabaseCatalog::instance().tryGetTable(resolved_id, session_context); StoragePtr table = DatabaseCatalog::instance().tryGetTable(resolved_id, context_to_resolve_table_names);
if (!table) if (!table)
continue; continue;
@ -781,11 +774,10 @@ void TCPHandler::processTablesStatusRequest()
writeVarUInt(Protocol::Server::TablesStatusResponse, *out); writeVarUInt(Protocol::Server::TablesStatusResponse, *out);
/// For testing hedged requests /// For testing hedged requests
const Settings & settings = query_context->getSettingsRef(); if (sleep_in_send_tables_status.totalMilliseconds())
if (settings.sleep_in_send_tables_status_ms.totalMilliseconds())
{ {
out->next(); out->next();
std::chrono::milliseconds ms(settings.sleep_in_send_tables_status_ms.totalMilliseconds()); std::chrono::milliseconds ms(sleep_in_send_tables_status.totalMilliseconds());
std::this_thread::sleep_for(ms); std::this_thread::sleep_for(ms);
} }
@ -977,22 +969,21 @@ void TCPHandler::receiveHello()
(!user.empty() ? ", user: " + user : "") (!user.empty() ? ", user: " + user : "")
); );
if (user != USER_INTERSERVER_MARKER) auto & client_info = session->getClientInfo();
{ client_info.client_name = client_name;
auto & client_info = session->getClientInfo(); client_info.client_version_major = client_version_major;
client_info.interface = ClientInfo::Interface::TCP; client_info.client_version_minor = client_version_minor;
client_info.client_name = client_name; client_info.client_version_patch = client_version_patch;
client_info.client_version_major = client_version_major; client_info.client_tcp_protocol_version = client_tcp_protocol_version;
client_info.client_version_minor = client_version_minor;
client_info.client_version_patch = client_version_patch;
client_info.client_tcp_protocol_version = client_tcp_protocol_version;
session->setUser(user, password, socket().peerAddress()); is_interserver_mode = (user == USER_INTERSERVER_MARKER);
} if (is_interserver_mode)
else
{ {
receiveClusterNameAndSalt(); receiveClusterNameAndSalt();
return;
} }
session->authenticate(user, password, socket().peerAddress());
} }
@ -1039,8 +1030,11 @@ bool TCPHandler::receivePacket()
{ {
case Protocol::Client::IgnoredPartUUIDs: case Protocol::Client::IgnoredPartUUIDs:
/// Part uuids packet if any comes before query. /// Part uuids packet if any comes before query.
if (!state.empty() || state.part_uuids_to_ignore)
receiveUnexpectedIgnoredPartUUIDs();
receiveIgnoredPartUUIDs(); receiveIgnoredPartUUIDs();
return true; return true;
case Protocol::Client::Query: case Protocol::Client::Query:
if (!state.empty()) if (!state.empty())
receiveUnexpectedQuery(); receiveUnexpectedQuery();
@ -1049,8 +1043,10 @@ bool TCPHandler::receivePacket()
case Protocol::Client::Data: case Protocol::Client::Data:
case Protocol::Client::Scalar: case Protocol::Client::Scalar:
if (state.skipping_data)
return receiveUnexpectedData(false);
if (state.empty()) if (state.empty())
receiveUnexpectedData(); receiveUnexpectedData(true);
return receiveData(packet_type == Protocol::Client::Scalar); return receiveData(packet_type == Protocol::Client::Scalar);
case Protocol::Client::Ping: case Protocol::Client::Ping:
@ -1061,10 +1057,9 @@ bool TCPHandler::receivePacket()
case Protocol::Client::Cancel: case Protocol::Client::Cancel:
{ {
/// For testing connection collector. /// For testing connection collector.
const Settings & settings = query_context->getSettingsRef(); if (sleep_in_receive_cancel.totalMilliseconds())
if (settings.sleep_in_receive_cancel_ms.totalMilliseconds())
{ {
std::chrono::milliseconds ms(settings.sleep_in_receive_cancel_ms.totalMilliseconds()); std::chrono::milliseconds ms(sleep_in_receive_cancel.totalMilliseconds());
std::this_thread::sleep_for(ms); std::this_thread::sleep_for(ms);
} }
@ -1086,14 +1081,18 @@ bool TCPHandler::receivePacket()
} }
} }
void TCPHandler::receiveIgnoredPartUUIDs() void TCPHandler::receiveIgnoredPartUUIDs()
{ {
state.part_uuids = true; readVectorBinary(state.part_uuids_to_ignore.emplace(), *in);
std::vector<UUID> uuids; }
readVectorBinary(uuids, *in);
if (!uuids.empty())
query_context->getIgnoredPartUUIDs()->add(uuids); void TCPHandler::receiveUnexpectedIgnoredPartUUIDs()
{
std::vector<UUID> skip_part_uuids;
readVectorBinary(skip_part_uuids, *in);
throw NetException("Unexpected packet IgnoredPartUUIDs received from client", ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT);
} }
@ -1107,10 +1106,9 @@ String TCPHandler::receiveReadTaskResponseAssumeLocked()
{ {
state.is_cancelled = true; state.is_cancelled = true;
/// For testing connection collector. /// For testing connection collector.
const Settings & settings = query_context->getSettingsRef(); if (sleep_in_receive_cancel.totalMilliseconds())
if (settings.sleep_in_receive_cancel_ms.totalMilliseconds())
{ {
std::chrono::milliseconds ms(settings.sleep_in_receive_cancel_ms.totalMilliseconds()); std::chrono::milliseconds ms(sleep_in_receive_cancel.totalMilliseconds());
std::this_thread::sleep_for(ms); std::this_thread::sleep_for(ms);
} }
return {}; return {};
@ -1141,14 +1139,14 @@ void TCPHandler::receiveClusterNameAndSalt()
if (salt.empty()) if (salt.empty())
throw NetException("Empty salt is not allowed", ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT); throw NetException("Empty salt is not allowed", ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT);
cluster_secret = query_context->getCluster(cluster)->getSecret(); cluster_secret = server.context()->getCluster(cluster)->getSecret();
} }
catch (const Exception & e) catch (const Exception & e)
{ {
try try
{ {
/// We try to send error information to the client. /// We try to send error information to the client.
sendException(e, session->getSettings().calculate_text_stack_trace); sendException(e, send_exception_with_stack_trace);
} }
catch (...) {} catch (...) {}
@ -1163,27 +1161,12 @@ void TCPHandler::receiveQuery()
state.is_empty = false; state.is_empty = false;
readStringBinary(state.query_id, *in); readStringBinary(state.query_id, *in);
// query_context = session->makeQueryContext(state.query_id);
/// Client info /// Read client info.
ClientInfo & client_info = query_context->getClientInfo(); ClientInfo client_info = session->getClientInfo();
if (client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_CLIENT_INFO) if (client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_CLIENT_INFO)
client_info.read(*in, client_tcp_protocol_version); client_info.read(*in, client_tcp_protocol_version);
/// For better support of old clients, that does not send ClientInfo.
if (client_info.query_kind == ClientInfo::QueryKind::NO_QUERY)
{
client_info.query_kind = ClientInfo::QueryKind::INITIAL_QUERY;
client_info.client_name = client_name;
client_info.client_version_major = client_version_major;
client_info.client_version_minor = client_version_minor;
client_info.client_version_patch = client_version_patch;
client_info.client_tcp_protocol_version = client_tcp_protocol_version;
}
/// Set fields, that are known apriori.
client_info.interface = ClientInfo::Interface::TCP;
/// Per query settings are also passed via TCP. /// Per query settings are also passed via TCP.
/// We need to check them before applying due to they can violate the settings constraints. /// We need to check them before applying due to they can violate the settings constraints.
auto settings_format = (client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_SETTINGS_SERIALIZED_AS_STRINGS) auto settings_format = (client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_SETTINGS_SERIALIZED_AS_STRINGS)
@ -1204,12 +1187,11 @@ void TCPHandler::receiveQuery()
readVarUInt(compression, *in); readVarUInt(compression, *in);
state.compression = static_cast<Protocol::Compression>(compression); state.compression = static_cast<Protocol::Compression>(compression);
last_block_in.compression = state.compression;
readStringBinary(state.query, *in); readStringBinary(state.query, *in);
/// It is OK to check only when query != INITIAL_QUERY, if (is_interserver_mode)
/// since only in that case the actions will be done.
if (!cluster.empty() && client_info.query_kind != ClientInfo::QueryKind::INITIAL_QUERY)
{ {
#if USE_SSL #if USE_SSL
std::string data(salt); std::string data(salt);
@ -1231,26 +1213,33 @@ void TCPHandler::receiveQuery()
/// i.e. when the INSERT is done with the global context (w/o user). /// i.e. when the INSERT is done with the global context (w/o user).
if (!client_info.initial_user.empty()) if (!client_info.initial_user.empty())
{ {
query_context->setUserWithoutCheckingPassword(client_info.initial_user, client_info.initial_address); LOG_DEBUG(log, "User (initial): {}", client_info.initial_user);
LOG_DEBUG(log, "User (initial): {}", query_context->getUserName()); session->authenticate(AlwaysAllowCredentials{client_info.initial_user}, client_info.initial_address);
} }
/// No need to update connection_context, since it does not requires user (it will not be used for query execution)
#else #else
throw Exception( throw Exception(
"Inter-server secret support is disabled, because ClickHouse was built without SSL library", "Inter-server secret support is disabled, because ClickHouse was built without SSL library",
ErrorCodes::SUPPORT_IS_DISABLED); ErrorCodes::SUPPORT_IS_DISABLED);
#endif #endif
} }
else
{ query_context = session->makeQueryContext(std::move(client_info));
query_context->setInitialRowPolicy();
} /// Sets the default database if it wasn't set earlier for the session context.
if (!default_database.empty() && !session->sessionContext())
query_context->setCurrentDatabase(default_database);
if (state.part_uuids_to_ignore)
query_context->getIgnoredPartUUIDs()->add(*state.part_uuids_to_ignore);
query_context->setProgressCallback([this] (const Progress & value) { return this->updateProgress(value); });
/// ///
/// Settings /// Settings
/// ///
auto settings_changes = passed_settings.changes(); auto settings_changes = passed_settings.changes();
if (client_info.query_kind == ClientInfo::QueryKind::INITIAL_QUERY) auto query_kind = query_context->getClientInfo().query_kind;
if (query_kind == ClientInfo::QueryKind::INITIAL_QUERY)
{ {
/// Throw an exception if the passed settings violate the constraints. /// Throw an exception if the passed settings violate the constraints.
query_context->checkSettingsConstraints(settings_changes); query_context->checkSettingsConstraints(settings_changes);
@ -1262,40 +1251,24 @@ void TCPHandler::receiveQuery()
} }
query_context->applySettingsChanges(settings_changes); query_context->applySettingsChanges(settings_changes);
/// Use the received query id, or generate a random default. It is convenient
/// to also generate the default OpenTelemetry trace id at the same time, and
/// set the trace parent.
/// Notes:
/// 1) ClientInfo might contain upstream trace id, so we decide whether to use
/// the default ids after we have received the ClientInfo.
/// 2) There is the opentelemetry_start_trace_probability setting that
/// controls when we start a new trace. It can be changed via Native protocol,
/// so we have to apply the changes first.
query_context->setCurrentQueryId(state.query_id);
/// Disable function name normalization when it's a secondary query, because queries are either /// Disable function name normalization when it's a secondary query, because queries are either
/// already normalized on initiator node, or not normalized and should remain unnormalized for /// already normalized on initiator node, or not normalized and should remain unnormalized for
/// compatibility. /// compatibility.
if (client_info.query_kind == ClientInfo::QueryKind::SECONDARY_QUERY) if (query_kind == ClientInfo::QueryKind::SECONDARY_QUERY)
{ {
query_context->setSetting("normalize_function_names", Field(0)); query_context->setSetting("normalize_function_names", Field(0));
} }
// Use the received query id, or generate a random default. It is convenient
// to also generate the default OpenTelemetry trace id at the same time, and
// set the trace parent.
// Why is this done here and not earlier:
// 1) ClientInfo might contain upstream trace id, so we decide whether to use
// the default ids after we have received the ClientInfo.
// 2) There is the opentelemetry_start_trace_probability setting that
// controls when we start a new trace. It can be changed via Native protocol,
// so we have to apply the changes first.
query_context->setCurrentQueryId(state.query_id);
// Set parameters of initial query.
if (client_info.query_kind == ClientInfo::QueryKind::INITIAL_QUERY)
{
/// 'Current' fields was set at receiveHello.
client_info.initial_user = client_info.current_user;
client_info.initial_query_id = client_info.current_query_id;
client_info.initial_address = client_info.current_address;
}
/// Sync timeouts on client and server during current query to avoid dangling queries on server
/// NOTE: We use settings.send_timeout for the receive timeout and vice versa (change arguments ordering in TimeoutSetter),
/// because settings.send_timeout is client-side setting which has opposite meaning on the server side.
/// NOTE: these settings are applied only for current connection (not for distributed tables' connections)
const Settings & settings = query_context->getSettingsRef();
state.timeout_setter = std::make_unique<TimeoutSetter>(socket(), settings.receive_timeout, settings.send_timeout);
} }
void TCPHandler::receiveUnexpectedQuery() void TCPHandler::receiveUnexpectedQuery()
@ -1320,7 +1293,10 @@ void TCPHandler::receiveUnexpectedQuery()
readStringBinary(skip_hash, *in, 32); readStringBinary(skip_hash, *in, 32);
readVarUInt(skip_uint_64, *in); readVarUInt(skip_uint_64, *in);
readVarUInt(skip_uint_64, *in); readVarUInt(skip_uint_64, *in);
last_block_in.compression = static_cast<Protocol::Compression>(skip_uint_64);
readStringBinary(skip_string, *in); readStringBinary(skip_string, *in);
throw NetException("Unexpected packet Query received from client", ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT); throw NetException("Unexpected packet Query received from client", ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT);
@ -1337,73 +1313,77 @@ bool TCPHandler::receiveData(bool scalar)
/// Read one block from the network and write it down /// Read one block from the network and write it down
Block block = state.block_in->read(); Block block = state.block_in->read();
if (block) if (!block)
{ {
if (scalar) state.read_all_data = true;
{ return false;
/// Scalar value }
query_context->addScalar(temporary_id.table_name, block);
}
else if (!state.need_receive_data_for_insert && !state.need_receive_data_for_input)
{
/// Data for external tables
auto resolved = query_context->tryResolveStorageID(temporary_id, Context::ResolveExternal); if (scalar)
StoragePtr storage; {
/// If such a table does not exist, create it. /// Scalar value
if (resolved) query_context->addScalar(temporary_id.table_name, block);
{ }
storage = DatabaseCatalog::instance().getTable(resolved, query_context); else if (!state.need_receive_data_for_insert && !state.need_receive_data_for_input)
} {
else /// Data for external tables
{
NamesAndTypesList columns = block.getNamesAndTypesList();
auto temporary_table = TemporaryTableHolder(query_context, ColumnsDescription{columns}, {});
storage = temporary_table.getTable();
query_context->addExternalTable(temporary_id.table_name, std::move(temporary_table));
}
auto metadata_snapshot = storage->getInMemoryMetadataPtr();
/// The data will be written directly to the table.
auto temporary_table_out = std::make_shared<PushingToSinkBlockOutputStream>(storage->write(ASTPtr(), metadata_snapshot, query_context));
temporary_table_out->write(block);
temporary_table_out->writeSuffix();
} auto resolved = query_context->tryResolveStorageID(temporary_id, Context::ResolveExternal);
else if (state.need_receive_data_for_input) StoragePtr storage;
/// If such a table does not exist, create it.
if (resolved)
{ {
/// 'input' table function. storage = DatabaseCatalog::instance().getTable(resolved, query_context);
state.block_for_input = block;
} }
else else
{ {
/// INSERT query. NamesAndTypesList columns = block.getNamesAndTypesList();
state.io.out->write(block); auto temporary_table = TemporaryTableHolder(query_context, ColumnsDescription{columns}, {});
storage = temporary_table.getTable();
query_context->addExternalTable(temporary_id.table_name, std::move(temporary_table));
} }
return true; auto metadata_snapshot = storage->getInMemoryMetadataPtr();
/// The data will be written directly to the table.
auto temporary_table_out = std::make_shared<PushingToSinkBlockOutputStream>(storage->write(ASTPtr(), metadata_snapshot, query_context));
temporary_table_out->write(block);
temporary_table_out->writeSuffix();
}
else if (state.need_receive_data_for_input)
{
/// 'input' table function.
state.block_for_input = block;
} }
else else
return false; {
/// INSERT query.
state.io.out->write(block);
}
return true;
} }
void TCPHandler::receiveUnexpectedData()
bool TCPHandler::receiveUnexpectedData(bool throw_exception)
{ {
String skip_external_table_name; String skip_external_table_name;
readStringBinary(skip_external_table_name, *in); readStringBinary(skip_external_table_name, *in);
std::shared_ptr<ReadBuffer> maybe_compressed_in; std::shared_ptr<ReadBuffer> maybe_compressed_in;
if (last_block_in.compression == Protocol::Compression::Enable) if (last_block_in.compression == Protocol::Compression::Enable)
maybe_compressed_in = std::make_shared<CompressedReadBuffer>(*in, /* allow_different_codecs */ true); maybe_compressed_in = std::make_shared<CompressedReadBuffer>(*in, /* allow_different_codecs */ true);
else else
maybe_compressed_in = in; maybe_compressed_in = in;
auto skip_block_in = std::make_shared<NativeBlockInputStream>( auto skip_block_in = std::make_shared<NativeBlockInputStream>(*maybe_compressed_in, client_tcp_protocol_version);
*maybe_compressed_in, bool read_ok = skip_block_in->read();
last_block_in.header,
client_tcp_protocol_version);
skip_block_in->read(); if (!read_ok)
throw NetException("Unexpected packet Data received from client", ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT); state.read_all_data = true;
if (throw_exception)
throw NetException("Unexpected packet Data received from client", ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT);
return read_ok;
} }
void TCPHandler::initBlockInput() void TCPHandler::initBlockInput()
@ -1424,9 +1404,6 @@ void TCPHandler::initBlockInput()
else if (state.need_receive_data_for_input) else if (state.need_receive_data_for_input)
header = state.input_header; header = state.input_header;
last_block_in.header = header;
last_block_in.compression = state.compression;
state.block_in = std::make_shared<NativeBlockInputStream>( state.block_in = std::make_shared<NativeBlockInputStream>(
*state.maybe_compressed_in, *state.maybe_compressed_in,
header, header,
@ -1439,10 +1416,9 @@ void TCPHandler::initBlockOutput(const Block & block)
{ {
if (!state.block_out) if (!state.block_out)
{ {
const Settings & query_settings = query_context->getSettingsRef();
if (!state.maybe_compressed_out) if (!state.maybe_compressed_out)
{ {
const Settings & query_settings = query_context->getSettingsRef();
std::string method = Poco::toUpper(query_settings.network_compression_method.toString()); std::string method = Poco::toUpper(query_settings.network_compression_method.toString());
std::optional<int> level; std::optional<int> level;
if (method == "ZSTD") if (method == "ZSTD")
@ -1463,7 +1439,7 @@ void TCPHandler::initBlockOutput(const Block & block)
*state.maybe_compressed_out, *state.maybe_compressed_out,
client_tcp_protocol_version, client_tcp_protocol_version,
block.cloneEmpty(), block.cloneEmpty(),
!session->getSettings().low_cardinality_allow_in_native_format); !query_settings.low_cardinality_allow_in_native_format);
} }
} }
@ -1472,11 +1448,12 @@ void TCPHandler::initLogsBlockOutput(const Block & block)
if (!state.logs_block_out) if (!state.logs_block_out)
{ {
/// Use uncompressed stream since log blocks usually contain only one row /// Use uncompressed stream since log blocks usually contain only one row
const Settings & query_settings = query_context->getSettingsRef();
state.logs_block_out = std::make_shared<NativeBlockOutputStream>( state.logs_block_out = std::make_shared<NativeBlockOutputStream>(
*out, *out,
client_tcp_protocol_version, client_tcp_protocol_version,
block.cloneEmpty(), block.cloneEmpty(),
!session->getSettings().low_cardinality_allow_in_native_format); !query_settings.low_cardinality_allow_in_native_format);
} }
} }
@ -1486,7 +1463,7 @@ bool TCPHandler::isQueryCancelled()
if (state.is_cancelled || state.sent_all_data) if (state.is_cancelled || state.sent_all_data)
return true; return true;
if (after_check_cancelled.elapsed() / 1000 < query_context->getSettingsRef().interactive_delay) if (after_check_cancelled.elapsed() / 1000 < interactive_delay)
return false; return false;
after_check_cancelled.restart(); after_check_cancelled.restart();
@ -1514,10 +1491,9 @@ bool TCPHandler::isQueryCancelled()
state.is_cancelled = true; state.is_cancelled = true;
/// For testing connection collector. /// For testing connection collector.
{ {
const Settings & settings = query_context->getSettingsRef(); if (sleep_in_receive_cancel.totalMilliseconds())
if (settings.sleep_in_receive_cancel_ms.totalMilliseconds())
{ {
std::chrono::milliseconds ms(settings.sleep_in_receive_cancel_ms.totalMilliseconds()); std::chrono::milliseconds ms(sleep_in_receive_cancel.totalMilliseconds());
std::this_thread::sleep_for(ms); std::this_thread::sleep_for(ms);
} }
} }
@ -1555,11 +1531,10 @@ void TCPHandler::sendData(const Block & block)
writeStringBinary("", *out); writeStringBinary("", *out);
/// For testing hedged requests /// For testing hedged requests
const Settings & settings = query_context->getSettingsRef(); if (block.rows() > 0 && query_context->getSettingsRef().sleep_in_send_data_ms.totalMilliseconds())
if (block.rows() > 0 && settings.sleep_in_send_data_ms.totalMilliseconds())
{ {
out->next(); out->next();
std::chrono::milliseconds ms(settings.sleep_in_send_data_ms.totalMilliseconds()); std::chrono::milliseconds ms(query_context->getSettingsRef().sleep_in_send_data_ms.totalMilliseconds());
std::this_thread::sleep_for(ms); std::this_thread::sleep_for(ms);
} }

View File

@ -27,7 +27,9 @@ namespace DB
{ {
class Session; class Session;
struct Settings;
class ColumnsDescription; class ColumnsDescription;
struct BlockStreamProfileInfo;
/// State of query processing. /// State of query processing.
struct QueryState struct QueryState
@ -66,11 +68,11 @@ struct QueryState
bool sent_all_data = false; bool sent_all_data = false;
/// Request requires data from the client (INSERT, but not INSERT SELECT). /// Request requires data from the client (INSERT, but not INSERT SELECT).
bool need_receive_data_for_insert = false; bool need_receive_data_for_insert = false;
/// Temporary tables read /// Data was read.
bool temporary_tables_read = false; bool read_all_data = false;
/// A state got uuids to exclude from a query /// A state got uuids to exclude from a query
bool part_uuids = false; std::optional<std::vector<UUID>> part_uuids_to_ignore;
/// Request requires data from client for function input() /// Request requires data from client for function input()
bool need_receive_data_for_input = false; bool need_receive_data_for_input = false;
@ -79,6 +81,9 @@ struct QueryState
/// sample block from StorageInput /// sample block from StorageInput
Block input_header; Block input_header;
/// If true, the data packets will be skipped instead of reading. Used to recover after errors.
bool skipping_data = false;
/// To output progress, the difference after the previous sending of progress. /// To output progress, the difference after the previous sending of progress.
Progress progress; Progress progress;
@ -100,7 +105,6 @@ struct QueryState
struct LastBlockInputParameters struct LastBlockInputParameters
{ {
Protocol::Compression compression = Protocol::Compression::Disable; Protocol::Compression compression = Protocol::Compression::Disable;
Block header;
}; };
class TCPHandler : public Poco::Net::TCPServerConnection class TCPHandler : public Poco::Net::TCPServerConnection
@ -133,11 +137,20 @@ private:
UInt64 client_version_patch = 0; UInt64 client_version_patch = 0;
UInt64 client_tcp_protocol_version = 0; UInt64 client_tcp_protocol_version = 0;
/// Connection settings, which are extracted from a context.
bool send_exception_with_stack_trace = true;
Poco::Timespan send_timeout = DBMS_DEFAULT_SEND_TIMEOUT_SEC;
Poco::Timespan receive_timeout = DBMS_DEFAULT_RECEIVE_TIMEOUT_SEC;
UInt64 poll_interval = DBMS_DEFAULT_POLL_INTERVAL;
UInt64 idle_connection_timeout = 3600;
UInt64 interactive_delay = 100000;
Poco::Timespan sleep_in_send_tables_status;
UInt64 unknown_packet_in_send_data = 0;
Poco::Timespan sleep_in_receive_cancel;
std::unique_ptr<Session> session; std::unique_ptr<Session> session;
ContextMutablePtr query_context; ContextMutablePtr query_context;
size_t unknown_packet_in_send_data = 0;
/// Streams for reading/writing from/to client connection socket. /// Streams for reading/writing from/to client connection socket.
std::shared_ptr<ReadBuffer> in; std::shared_ptr<ReadBuffer> in;
std::shared_ptr<WriteBuffer> out; std::shared_ptr<WriteBuffer> out;
@ -149,6 +162,7 @@ private:
String default_database; String default_database;
/// For inter-server secret (remote_server.*.secret) /// For inter-server secret (remote_server.*.secret)
bool is_interserver_mode = false;
String salt; String salt;
String cluster; String cluster;
String cluster_secret; String cluster_secret;
@ -168,6 +182,8 @@ private:
void runImpl(); void runImpl();
void extractConnectionSettingsFromContext(const ContextPtr & context);
bool receiveProxyHeader(); bool receiveProxyHeader();
void receiveHello(); void receiveHello();
bool receivePacket(); bool receivePacket();
@ -175,18 +191,19 @@ private:
void receiveIgnoredPartUUIDs(); void receiveIgnoredPartUUIDs();
String receiveReadTaskResponseAssumeLocked(); String receiveReadTaskResponseAssumeLocked();
bool receiveData(bool scalar); bool receiveData(bool scalar);
bool readDataNext(size_t poll_interval, time_t receive_timeout); bool readDataNext();
void readData(const Settings & connection_settings); void readData();
void skipData();
void receiveClusterNameAndSalt(); void receiveClusterNameAndSalt();
std::tuple<size_t, int> getReadTimeouts(const Settings & connection_settings);
[[noreturn]] void receiveUnexpectedData(); bool receiveUnexpectedData(bool throw_exception = true);
[[noreturn]] void receiveUnexpectedQuery(); [[noreturn]] void receiveUnexpectedQuery();
[[noreturn]] void receiveUnexpectedIgnoredPartUUIDs();
[[noreturn]] void receiveUnexpectedHello(); [[noreturn]] void receiveUnexpectedHello();
[[noreturn]] void receiveUnexpectedTablesStatusRequest(); [[noreturn]] void receiveUnexpectedTablesStatusRequest();
/// Process INSERT query /// Process INSERT query
void processInsertQuery(const Settings & connection_settings); void processInsertQuery();
/// Process a request that does not require the receiving of data blocks from the client /// Process a request that does not require the receiving of data blocks from the client
void processOrdinaryQuery(); void processOrdinaryQuery();

View File

@ -61,9 +61,8 @@ void TableFunctionMySQL::parseArguments(const ASTPtr & ast_function, ContextPtr
user_name = args[3]->as<ASTLiteral &>().value.safeGet<String>(); user_name = args[3]->as<ASTLiteral &>().value.safeGet<String>();
password = args[4]->as<ASTLiteral &>().value.safeGet<String>(); password = args[4]->as<ASTLiteral &>().value.safeGet<String>();
const auto & settings = context->getSettingsRef();
/// Split into replicas if needed. 3306 is the default MySQL port number /// Split into replicas if needed. 3306 is the default MySQL port number
const size_t max_addresses = settings.glob_expansion_max_elements; size_t max_addresses = context->getSettingsRef().glob_expansion_max_elements;
auto addresses = parseRemoteDescriptionForExternalDatabase(host_port, max_addresses, 3306); auto addresses = parseRemoteDescriptionForExternalDatabase(host_port, max_addresses, 3306);
pool.emplace(remote_database_name, addresses, user_name, password); pool.emplace(remote_database_name, addresses, user_name, password);

View File

@ -24,3 +24,4 @@ def test_different_versions(start_cluster):
node.query("SELECT 1", settings={'max_concurrent_queries_for_user': 1}) node.query("SELECT 1", settings={'max_concurrent_queries_for_user': 1})
assert node.contains_in_log('Too many simultaneous queries for user') assert node.contains_in_log('Too many simultaneous queries for user')
assert not node.contains_in_log('Unknown packet') assert not node.contains_in_log('Unknown packet')
assert not node.contains_in_log('Unexpected packet')

View File

@ -1,8 +1,20 @@
===http=== ===http===
{"query":"select 1 from remote('127.0.0.2', system, one) format Null\n","status":"QueryFinish","tracestate":"some custom state","sorted_by_start_time":1}
{"query":"DESC TABLE system.one","status":"QueryFinish","tracestate":"some custom state","sorted_by_start_time":1}
{"query":"DESC TABLE system.one","status":"QueryFinish","tracestate":"some custom state","sorted_by_start_time":1}
{"query":"SELECT 1 FROM system.one","status":"QueryFinish","tracestate":"some custom state","sorted_by_start_time":1}
{"query":"DESC TABLE system.one","query_status":"QueryFinish","tracestate":"some custom state","sorted_by_finish_time":1}
{"query":"DESC TABLE system.one","query_status":"QueryFinish","tracestate":"some custom state","sorted_by_finish_time":1}
{"query":"SELECT 1 FROM system.one","query_status":"QueryFinish","tracestate":"some custom state","sorted_by_finish_time":1}
{"query":"select 1 from remote('127.0.0.2', system, one) format Null\n","query_status":"QueryFinish","tracestate":"some custom state","sorted_by_finish_time":1}
{"total spans":"4","unique spans":"4","unique non-zero parent spans":"3"} {"total spans":"4","unique spans":"4","unique non-zero parent spans":"3"}
{"initial query spans with proper parent":"1"} {"initial query spans with proper parent":"1"}
{"unique non-empty tracestate values":"1"} {"unique non-empty tracestate values":"1"}
===native=== ===native===
{"query":"select * from url('http:\/\/127.0.0.2:8123\/?query=select%201%20format%20Null', CSV, 'a int')","status":"QueryFinish","tracestate":"another custom state","sorted_by_start_time":1}
{"query":"select 1 format Null\n","status":"QueryFinish","tracestate":"another custom state","sorted_by_start_time":1}
{"query":"select 1 format Null\n","query_status":"QueryFinish","tracestate":"another custom state","sorted_by_finish_time":1}
{"query":"select * from url('http:\/\/127.0.0.2:8123\/?query=select%201%20format%20Null', CSV, 'a int')","query_status":"QueryFinish","tracestate":"another custom state","sorted_by_finish_time":1}
{"total spans":"2","unique spans":"2","unique non-zero parent spans":"2"} {"total spans":"2","unique spans":"2","unique non-zero parent spans":"2"}
{"initial query spans with proper parent":"1"} {"initial query spans with proper parent":"1"}
{"unique non-empty tracestate values":"1"} {"unique non-empty tracestate values":"1"}

View File

@ -12,6 +12,28 @@ function check_log
${CLICKHOUSE_CLIENT} --format=JSONEachRow -nq " ${CLICKHOUSE_CLIENT} --format=JSONEachRow -nq "
system flush logs; system flush logs;
-- Show queries sorted by start time.
select attribute['db.statement'] as query,
attribute['clickhouse.query_status'] as status,
attribute['clickhouse.tracestate'] as tracestate,
1 as sorted_by_start_time
from system.opentelemetry_span_log
where trace_id = reinterpretAsUUID(reverse(unhex('$trace_id')))
and operation_name = 'query'
order by start_time_us
;
-- Show queries sorted by finish time.
select attribute['db.statement'] as query,
attribute['clickhouse.query_status'] as query_status,
attribute['clickhouse.tracestate'] as tracestate,
1 as sorted_by_finish_time
from system.opentelemetry_span_log
where trace_id = reinterpretAsUUID(reverse(unhex('$trace_id')))
and operation_name = 'query'
order by finish_time_us
;
-- Check the number of query spans with given trace id, to verify it was -- Check the number of query spans with given trace id, to verify it was
-- propagated. -- propagated.
select count(*) "'"'"total spans"'"'", select count(*) "'"'"total spans"'"'",
@ -89,10 +111,10 @@ check_log
echo "===sampled===" echo "===sampled==="
query_id=$(${CLICKHOUSE_CLIENT} -q "select lower(hex(reverse(reinterpretAsString(generateUUIDv4()))))") query_id=$(${CLICKHOUSE_CLIENT} -q "select lower(hex(reverse(reinterpretAsString(generateUUIDv4()))))")
for i in {1..200} for i in {1..20}
do do
${CLICKHOUSE_CLIENT} \ ${CLICKHOUSE_CLIENT} \
--opentelemetry_start_trace_probability=0.1 \ --opentelemetry_start_trace_probability=0.5 \
--query_id "$query_id-$i" \ --query_id "$query_id-$i" \
--query "select 1 from remote('127.0.0.2', system, one) format Null" \ --query "select 1 from remote('127.0.0.2', system, one) format Null" \
& &
@ -108,8 +130,8 @@ wait
${CLICKHOUSE_CLIENT} -q "system flush logs" ${CLICKHOUSE_CLIENT} -q "system flush logs"
${CLICKHOUSE_CLIENT} -q " ${CLICKHOUSE_CLIENT} -q "
-- expect 200 * 0.1 = 20 sampled events on average -- expect 20 * 0.5 = 10 sampled events on average
select if(count() > 1 and count() < 50, 'OK', 'Fail') select if(2 <= count() and count() <= 18, 'OK', 'Fail')
from system.opentelemetry_span_log from system.opentelemetry_span_log
where operation_name = 'query' where operation_name = 'query'
and parent_span_id = 0 -- only account for the initial queries and parent_span_id = 0 -- only account for the initial queries