Merge pull request #26864 from vitlibar/refactor-sessions

Introduce sessions
This commit is contained in:
Vitaly Baranov 2021-08-19 01:38:51 +03:00 committed by GitHub
commit 65ee9a1272
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 983 additions and 714 deletions

View File

@ -381,7 +381,7 @@ void LocalServer::processQueries()
context->makeSessionContext();
context->makeQueryContext();
context->setUser("default", "", Poco::Net::SocketAddress{});
context->authenticate("default", "", Poco::Net::SocketAddress{});
context->setCurrentQueryId("");
applyCmdSettings(context);

View File

@ -53,6 +53,7 @@
#include <Interpreters/ExternalLoaderXMLConfigRepository.h>
#include <Interpreters/InterserverCredentials.h>
#include <Interpreters/JIT/CompiledExpressionCache.h>
#include <Interpreters/Session.h>
#include <Access/AccessControlManager.h>
#include <Storages/StorageReplicatedMergeTree.h>
#include <Storages/System/attachSystemTables.h>
@ -1428,7 +1429,7 @@ if (ThreadFuzzer::instance().isEffective())
/// Must be done after initialization of `servers`, because async_metrics will access `servers` variable from its thread.
async_metrics.start();
global_context->enableNamedSessions();
Session::startupNamedSessions();
{
String level_str = config().getString("text_log.level", "");

View File

@ -70,6 +70,7 @@ public:
/// Returns the current user. The function can return nullptr.
UserPtr getUser() const;
String getUserName() const;
std::optional<UUID> getUserID() const { return getParams().user_id; }
/// Returns information about current and enabled roles.
std::shared_ptr<const EnabledRolesInfo> getRolesInfo() const;

View File

@ -26,6 +26,8 @@ protected:
String user_name;
};
/// Does not check the password/credentials and that the specified host is allowed.
/// (Used only internally in cluster, if the secret matches)
class AlwaysAllowCredentials
: public Credentials
{

View File

@ -5,6 +5,7 @@
#include <Poco/Net/HTTPRequest.h>
#include <Poco/URI.h>
#include <filesystem>
#include <thread>
namespace fs = std::filesystem;

View File

@ -2,8 +2,7 @@
#include <Core/MySQL/PacketsConnection.h>
#include <Poco/RandomStream.h>
#include <Poco/SHA1Engine.h>
#include <Access/User.h>
#include <Access/AccessControlManager.h>
#include <Interpreters/Session.h>
#include <common/logger_useful.h>
#include <Common/OpenSSLHelpers.h>
@ -73,7 +72,7 @@ Native41::Native41(const String & password, const String & auth_plugin_data)
}
void Native41::authenticate(
const String & user_name, std::optional<String> auth_response, ContextMutablePtr context,
const String & user_name, Session & session, std::optional<String> auth_response,
std::shared_ptr<PacketEndpoint> packet_endpoint, bool, const Poco::Net::SocketAddress & address)
{
if (!auth_response)
@ -86,7 +85,7 @@ void Native41::authenticate(
if (auth_response->empty())
{
context->setUser(user_name, "", address);
session.authenticate(user_name, "", address);
return;
}
@ -96,9 +95,7 @@ void Native41::authenticate(
+ " bytes, received: " + std::to_string(auth_response->size()) + " bytes.",
ErrorCodes::UNKNOWN_EXCEPTION);
auto user = context->getAccessControlManager().read<User>(user_name);
Poco::SHA1Engine::Digest double_sha1_value = user->authentication.getPasswordDoubleSHA1();
Poco::SHA1Engine::Digest double_sha1_value = session.getPasswordDoubleSHA1(user_name);
assert(double_sha1_value.size() == Poco::SHA1Engine::DIGEST_SIZE);
Poco::SHA1Engine engine;
@ -111,7 +108,7 @@ void Native41::authenticate(
{
password_sha1[i] = digest[i] ^ static_cast<unsigned char>((*auth_response)[i]);
}
context->setUser(user_name, password_sha1, address);
session.authenticate(user_name, password_sha1, address);
}
#if USE_SSL
@ -136,7 +133,7 @@ Sha256Password::Sha256Password(RSA & public_key_, RSA & private_key_, Poco::Logg
}
void Sha256Password::authenticate(
const String & user_name, std::optional<String> auth_response, ContextMutablePtr context,
const String & user_name, Session & session, std::optional<String> auth_response,
std::shared_ptr<PacketEndpoint> packet_endpoint, bool is_secure_connection, const Poco::Net::SocketAddress & address)
{
if (!auth_response)
@ -231,7 +228,7 @@ void Sha256Password::authenticate(
password.pop_back();
}
context->setUser(user_name, password, address);
session.authenticate(user_name, password, address);
}
#endif

View File

@ -15,6 +15,7 @@
namespace DB
{
class Session;
namespace MySQLProtocol
{
@ -32,7 +33,7 @@ public:
virtual String getAuthPluginData() = 0;
virtual void authenticate(
const String & user_name, std::optional<String> auth_response, ContextMutablePtr context,
const String & user_name, Session & session, std::optional<String> auth_response,
std::shared_ptr<PacketEndpoint> packet_endpoint, bool is_secure_connection, const Poco::Net::SocketAddress & address) = 0;
};
@ -49,7 +50,7 @@ public:
String getAuthPluginData() override { return scramble; }
void authenticate(
const String & user_name, std::optional<String> auth_response, ContextMutablePtr context,
const String & user_name, Session & session, std::optional<String> auth_response,
std::shared_ptr<PacketEndpoint> packet_endpoint, bool /* is_secure_connection */, const Poco::Net::SocketAddress & address) override;
private:
@ -69,7 +70,7 @@ public:
String getAuthPluginData() override { return scramble; }
void authenticate(
const String & user_name, std::optional<String> auth_response, ContextMutablePtr context,
const String & user_name, Session & session, std::optional<String> auth_response,
std::shared_ptr<PacketEndpoint> packet_endpoint, bool is_secure_connection, const Poco::Net::SocketAddress & address) override;
private:

View File

@ -1,13 +1,11 @@
#pragma once
#include <Access/AccessControlManager.h>
#include <Access/User.h>
#include <functional>
#include <Interpreters/Context.h>
#include <IO/ReadBuffer.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteBuffer.h>
#include <IO/WriteHelpers.h>
#include <Interpreters/Session.h>
#include <common/logger_useful.h>
#include <Poco/Format.h>
#include <Poco/RegularExpression.h>
@ -803,12 +801,13 @@ protected:
static void setPassword(
const String & user_name,
const String & password,
ContextMutablePtr context,
Session & session,
Messaging::MessageTransport & mt,
const Poco::Net::SocketAddress & address)
{
try {
context->setUser(user_name, password, address);
try
{
session.authenticate(user_name, password, address);
}
catch (const Exception &)
{
@ -822,7 +821,7 @@ protected:
public:
virtual void authenticate(
const String & user_name,
ContextMutablePtr context,
Session & session,
Messaging::MessageTransport & mt,
const Poco::Net::SocketAddress & address) = 0;
@ -836,11 +835,11 @@ class NoPasswordAuth : public AuthenticationMethod
public:
void authenticate(
const String & user_name,
ContextMutablePtr context,
Session & session,
Messaging::MessageTransport & mt,
const Poco::Net::SocketAddress & address) override
{
setPassword(user_name, "", context, mt, address);
return setPassword(user_name, "", session, mt, address);
}
Authentication::Type getType() const override
@ -854,7 +853,7 @@ class CleartextPasswordAuth : public AuthenticationMethod
public:
void authenticate(
const String & user_name,
ContextMutablePtr context,
Session & session,
Messaging::MessageTransport & mt,
const Poco::Net::SocketAddress & address) override
{
@ -864,7 +863,7 @@ public:
if (type == Messaging::FrontMessageType::PASSWORD_MESSAGE)
{
std::unique_ptr<Messaging::PasswordMessage> password = mt.receive<Messaging::PasswordMessage>();
setPassword(user_name, password->password, context, mt, address);
return setPassword(user_name, password->password, session, mt, address);
}
else
throw Exception(
@ -897,16 +896,15 @@ public:
void authenticate(
const String & user_name,
ContextMutablePtr context,
Session & session,
Messaging::MessageTransport & mt,
const Poco::Net::SocketAddress & address)
{
auto user = context->getAccessControlManager().read<User>(user_name);
Authentication::Type user_auth_type = user->authentication.getType();
Authentication::Type user_auth_type = session.getAuthenticationType(user_name);
if (type_to_method.find(user_auth_type) != type_to_method.end())
{
type_to_method[user_auth_type]->authenticate(user_name, context, mt, address);
type_to_method[user_auth_type]->authenticate(user_name, session, mt, address);
mt.send(Messaging::AuthenticationOk(), true);
LOG_DEBUG(log, "Authentication for user {} was successful.", user_name);
return;

View File

@ -255,7 +255,7 @@ void registerDictionarySourceClickHouse(DictionarySourceFactory & factory)
/// We should set user info even for the case when the dictionary is loaded in-process (without TCP communication).
if (configuration.is_local)
{
context_copy->setUser(configuration.user, configuration.password, Poco::Net::SocketAddress("127.0.0.1", 0));
context_copy->authenticate(configuration.user, configuration.password, Poco::Net::SocketAddress("127.0.0.1", 0));
context_copy = copyContextAndApplySettings(config_prefix, context_copy, config);
}

View File

@ -12,6 +12,7 @@
#include <Common/UnicodeBar.h>
#include <Common/TerminalSize.h>
#include <IO/Operators.h>
#include <IO/Progress.h>
namespace ProfileEvents

View File

@ -100,7 +100,6 @@ namespace CurrentMetrics
extern const Metric BackgroundMessageBrokerSchedulePoolTask;
}
namespace DB
{
@ -115,189 +114,11 @@ namespace ErrorCodes
extern const int THERE_IS_NO_QUERY;
extern const int NO_ELEMENTS_IN_CONFIG;
extern const int TABLE_SIZE_EXCEEDS_MAX_DROP_SIZE_LIMIT;
extern const int SESSION_NOT_FOUND;
extern const int SESSION_IS_LOCKED;
extern const int LOGICAL_ERROR;
extern const int NOT_IMPLEMENTED;
}
/// Registry of named sessions keyed by (user name, session id).
/// Owns a background thread (named "SessionCleaner") that periodically erases sessions
/// whose close cycle has been reached and that are not held by any client.
class NamedSessions
{
public:
    using Key = NamedSessionKey;

    /// Signals the cleaner thread to quit and joins it.
    /// Any exception raised while doing so is logged instead of propagated.
    ~NamedSessions()
    {
        try
        {
            {
                std::lock_guard lock{mutex};
                quit = true;
            }

            cond.notify_one();
            thread.join();
        }
        catch (...)
        {
            tryLogCurrentException(__PRETTY_FUNCTION__);
        }
    }

    /// Find existing session or create a new.
    ///
    /// The session is looked up by (current user of `context`, `session_id`).
    /// Throws LOGICAL_ERROR if the user name is empty, SESSION_NOT_FOUND if the session is absent
    /// and `throw_if_not_found` is set, SESSION_IS_LOCKED if somebody else holds the session.
    /// On success the session's client info is overwritten with the one from `context`.
    std::shared_ptr<NamedSession> acquireSession(
        const String & session_id,
        ContextMutablePtr context,
        std::chrono::steady_clock::duration timeout,
        bool throw_if_not_found)
    {
        std::unique_lock lock(mutex);

        const auto & user_name = context->client_info.current_user;

        if (user_name.empty())
            throw Exception("Empty user name.", ErrorCodes::LOGICAL_ERROR);

        Key key(user_name, session_id);

        auto it = sessions.find(key);
        if (it == sessions.end())
        {
            if (throw_if_not_found)
                throw Exception("Session not found.", ErrorCodes::SESSION_NOT_FOUND);

            /// Create a new session from current context.
            it = sessions.insert(std::make_pair(key, std::make_shared<NamedSession>(key, context, timeout, *this))).first;
        }
        /// No separate "belongs to a different user" check is needed here: the user name is part of the key,
        /// so a session found by this key always has the same user name as `context`.

        /// Use existing session.
        const auto & session = it->second;

        /// The map itself holds one reference; any additional reference means a client is using the session.
        if (session.use_count() != 1)
            throw Exception("Session is locked by a concurrent client.", ErrorCodes::SESSION_IS_LOCKED);

        session->context->client_info = context->client_info;

        return session;
    }

    /// Puts the session on the close queue according to its timeout.
    void releaseSession(NamedSession & session)
    {
        std::unique_lock lock(mutex);
        scheduleCloseSession(session, lock);
    }

private:
    class SessionKeyHash
    {
    public:
        size_t operator()(const Key & key) const
        {
            SipHash hash;
            hash.update(key.first);
            hash.update(key.second);
            return hash.get64();
        }
    };

    /// TODO it's very complicated. Make simple std::map with time_t or boost::multi_index.
    using Container = std::unordered_map<Key, std::shared_ptr<NamedSession>, SessionKeyHash>;
    using CloseTimes = std::deque<std::vector<Key>>;

    Container sessions;

    /// Queue of buckets of session keys; the front bucket is examined on each cleaning cycle.
    CloseTimes close_times;
    std::chrono::steady_clock::duration close_interval = std::chrono::seconds(1);
    std::chrono::steady_clock::time_point close_cycle_time = std::chrono::steady_clock::now();
    UInt64 close_cycle = 0;

    /// Must be called with `mutex` locked (the unused lock parameter documents that).
    void scheduleCloseSession(NamedSession & session, std::unique_lock<std::mutex> &)
    {
        /// Push it on a queue of sessions to close, on a position corresponding to the timeout.
        /// (timeout is measured from current moment of time)

        const UInt64 close_index = session.timeout / close_interval + 1;
        const auto new_close_cycle = close_cycle + close_index;

        if (session.close_cycle != new_close_cycle)
        {
            session.close_cycle = new_close_cycle;
            if (close_times.size() < close_index + 1)
                close_times.resize(close_index + 1);
            close_times[close_index].emplace_back(session.key);
        }
    }

    /// Body of the cleaner thread: closes expired sessions, then sleeps until the next cycle or until `quit` is set.
    void cleanThread()
    {
        setThreadName("SessionCleaner");

        std::unique_lock lock{mutex};

        while (true)
        {
            auto interval = closeSessions(lock);

            if (cond.wait_for(lock, interval, [this]() -> bool { return quit; }))
                break;
        }
    }

    /// Closes sessions that have expired. Returns how long to wait for the next session to expire, if no new sessions are added.
    std::chrono::steady_clock::duration closeSessions(std::unique_lock<std::mutex> & lock)
    {
        const auto now = std::chrono::steady_clock::now();

        /// The time to close the next session did not come
        if (now < close_cycle_time)
            return close_cycle_time - now; /// Will sleep until it comes.

        const auto current_cycle = close_cycle;

        ++close_cycle;
        close_cycle_time = now + close_interval;

        if (close_times.empty())
            return close_interval;

        auto & sessions_to_close = close_times.front();

        for (const auto & key : sessions_to_close)
        {
            const auto session = sessions.find(key);

            if (session != sessions.end() && session->second->close_cycle <= current_cycle)
            {
                if (session->second.use_count() != 1)
                {
                    /// Skip but move it to close on the next cycle.
                    /// NOTE(review): `close_cycle` has already been incremented above while the front bucket is not popped yet,
                    /// so the key lands in the bucket examined on the next pass with a `close_cycle` one greater than that pass;
                    /// it looks like the next pass will ignore it and the session is re-queued only by release(). Confirm.
                    session->second->timeout = std::chrono::steady_clock::duration{0};
                    scheduleCloseSession(*session->second, lock);
                }
                else
                    sessions.erase(session);
            }
        }

        close_times.pop_front();
        return close_interval;
    }

    std::mutex mutex;
    std::condition_variable cond;
    std::atomic<bool> quit{false};

    /// Declared last: the thread starts running cleanThread() during construction and uses the members above.
    ThreadFromGlobalPool thread{&NamedSessions::cleanThread, this};
};
/// Hands this session back to the owning NamedSessions registry,
/// which schedules it for closing (see NamedSessions::releaseSession).
void NamedSession::release()
{
parent.releaseSession(*this);
}
/** Set of known objects (environment), that could be used in query.
* Shared (global) part. Order of members (especially, order of destruction) is very important.
*/
@ -399,7 +220,6 @@ struct ContextSharedPart
RemoteHostFilter remote_host_filter; /// Allowed URL from config.xml
std::optional<TraceCollector> trace_collector; /// Thread collecting traces from threads executing queries
std::optional<NamedSessions> named_sessions; /// Controls named HTTP sessions.
/// Clusters for distributed tables
/// Initialized on demand (on distributed storages initialization) since Settings should be initialized
@ -588,7 +408,6 @@ void Context::copyFrom(const ContextPtr & other)
Context::~Context() = default;
InterserverIOHandler & Context::getInterserverIOHandler() { return shared->interserver_io_handler; }
std::unique_lock<std::recursive_mutex> Context::getLock() const
@ -605,21 +424,6 @@ const MergeList & Context::getMergeList() const { return shared->merge_list; }
ReplicatedFetchList & Context::getReplicatedFetchList() { return shared->replicated_fetch_list; }
const ReplicatedFetchList & Context::getReplicatedFetchList() const { return shared->replicated_fetch_list; }
/// Constructs the shared NamedSessions registry in place.
/// Constructing the registry also starts its cleaner thread (it is a member of NamedSessions).
void Context::enableNamedSessions()
{
shared->named_sessions.emplace();
}
/// Looks up (or creates) the named session `session_id` for the current user of this context.
/// Throws NOT_IMPLEMENTED when the named sessions registry has not been enabled.
std::shared_ptr<NamedSession>
Context::acquireNamedSession(const String & session_id, std::chrono::steady_clock::duration timeout, bool session_check)
{
    auto & registry = shared->named_sessions;

    if (registry)
        return registry->acquireSession(session_id, shared_from_this(), timeout, session_check);

    throw Exception("Support for named sessions is not enabled", ErrorCodes::NOT_IMPLEMENTED);
}
String Context::resolveDatabase(const String & database_name) const
{
String res = database_name.empty() ? getCurrentDatabase() : database_name;
@ -785,48 +589,45 @@ ConfigurationPtr Context::getUsersConfig()
}
void Context::setUser(const Credentials & credentials, const Poco::Net::SocketAddress & address)
void Context::authenticate(const String & name, const String & password, const Poco::Net::SocketAddress & address)
{
auto lock = getLock();
authenticate(BasicCredentials(name, password), address);
}
void Context::authenticate(const Credentials & credentials, const Poco::Net::SocketAddress & address)
{
auto authenticated_user_id = getAccessControlManager().login(credentials, address.host());
client_info.current_user = credentials.getUserName();
client_info.current_address = address;
#if defined(ARCADIA_BUILD)
/// This is harmful field that is used only in foreign "Arcadia" build.
client_info.current_password.clear();
if (const auto * basic_credentials = dynamic_cast<const BasicCredentials *>(&credentials))
client_info.current_password = basic_credentials->getPassword();
#endif
/// Find a user with such name and check the credentials.
auto new_user_id = getAccessControlManager().login(credentials, address.host());
auto new_access = getAccessControlManager().getContextAccess(
new_user_id, /* current_roles = */ {}, /* use_default_roles = */ true,
settings, current_database, client_info);
setUser(authenticated_user_id);
}
user_id = new_user_id;
access = std::move(new_access);
void Context::setUser(const UUID & user_id_)
{
auto lock = getLock();
user_id = user_id_;
access = getAccessControlManager().getContextAccess(
user_id_, /* current_roles = */ {}, /* use_default_roles = */ true, settings, current_database, client_info);
auto user = access->getUser();
current_roles = std::make_shared<std::vector<UUID>>(user->granted_roles.findGranted(user->default_roles));
if (!user->default_database.empty())
setCurrentDatabase(user->default_database);
auto default_profile_info = access->getDefaultProfileInfo();
settings_constraints_and_current_profiles = default_profile_info->getConstraintsAndProfileIDs();
applySettingsChanges(default_profile_info->settings);
}
void Context::setUser(const String & name, const String & password, const Poco::Net::SocketAddress & address)
{
setUser(BasicCredentials(name, password), address);
}
void Context::setUserWithoutCheckingPassword(const String & name, const Poco::Net::SocketAddress & address)
{
setUser(AlwaysAllowCredentials(name), address);
if (!user->default_database.empty())
setCurrentDatabase(user->default_database);
}
std::shared_ptr<const User> Context::getUser() const
@ -834,12 +635,6 @@ std::shared_ptr<const User> Context::getUser() const
return getAccess()->getUser();
}
/// Stores the quota key in this context's client info.
/// The parameter is taken by value and moved into place.
void Context::setQuotaKey(String quota_key_)
{
/// Hold the context lock while client_info is modified.
auto lock = getLock();
client_info.quota_key = std::move(quota_key_);
}
String Context::getUserName() const
{
return getAccess()->getUserName();
@ -852,6 +647,13 @@ std::optional<UUID> Context::getUserID() const
}
/// Stores the quota key in this context's client info.
/// The parameter is taken by value and moved into place.
void Context::setQuotaKey(String quota_key_)
{
/// Hold the context lock while client_info is modified.
auto lock = getLock();
client_info.quota_key = std::move(quota_key_);
}
void Context::setCurrentRoles(const std::vector<UUID> & current_roles_)
{
auto lock = getLock();
@ -933,10 +735,13 @@ ASTPtr Context::getRowPolicyCondition(const String & database, const String & ta
void Context::setInitialRowPolicy()
{
auto lock = getLock();
auto initial_user_id = getAccessControlManager().find<User>(client_info.initial_user);
initial_row_policy = nullptr;
if (initial_user_id)
initial_row_policy = getAccessControlManager().getEnabledRowPolicies(*initial_user_id, {});
if (client_info.initial_user == client_info.current_user)
return;
auto initial_user_id = getAccessControlManager().find<User>(client_info.initial_user);
if (!initial_user_id)
return;
initial_row_policy = getAccessControlManager().getEnabledRowPolicies(*initial_user_id, {});
}
@ -1377,6 +1182,9 @@ void Context::setCurrentQueryId(const String & query_id)
}
client_info.current_query_id = query_id_to_set;
if (client_info.query_kind == ClientInfo::QueryKind::INITIAL_QUERY)
client_info.initial_query_id = client_info.current_query_id;
}
void Context::killCurrentQuery()

View File

@ -14,21 +14,16 @@
#include <Common/MultiVersion.h>
#include <Common/OpenTelemetryTraceContext.h>
#include <Common/RemoteHostFilter.h>
#include <Common/ThreadPool.h>
#include <common/types.h>
#if !defined(ARCADIA_BUILD)
# include "config_core.h"
#endif
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <functional>
#include <memory>
#include <mutex>
#include <optional>
#include <thread>
namespace Poco::Net { class IPAddress; }
@ -67,6 +62,7 @@ class ProcessList;
class QueryStatus;
class Macros;
struct Progress;
struct FileProgress;
class Clusters;
class QueryLog;
class QueryThreadLog;
@ -107,6 +103,7 @@ using StoragePolicySelectorPtr = std::shared_ptr<const StoragePolicySelector>;
struct PartUUIDs;
using PartUUIDsPtr = std::shared_ptr<PartUUIDs>;
class KeeperStorageDispatcher;
class Session;
class IOutputFormat;
using OutputFormatPtr = std::shared_ptr<IOutputFormat>;
@ -287,8 +284,6 @@ public:
OpenTelemetryTraceContext query_trace_context;
private:
friend class NamedSessions;
using SampleBlockCache = std::unordered_map<std::string, Block>;
mutable SampleBlockCache sample_block_cache;
@ -367,23 +362,21 @@ public:
void setUsersConfig(const ConfigurationPtr & config);
ConfigurationPtr getUsersConfig();
/// Sets the current user, checks the credentials and that the specified host is allowed.
/// Must be called before getClientInfo() can be called.
void setUser(const Credentials & credentials, const Poco::Net::SocketAddress & address);
void setUser(const String & name, const String & password, const Poco::Net::SocketAddress & address);
/// Sets the current user, checks the credentials and that the specified address is allowed to connect from.
/// The function throws an exception if there is no such user or password is wrong.
void authenticate(const String & user_name, const String & password, const Poco::Net::SocketAddress & address);
void authenticate(const Credentials & credentials, const Poco::Net::SocketAddress & address);
/// Sets the current user, *does not check the password/credentials and that the specified host is allowed*.
/// Must be called before getClientInfo.
///
/// (Used only internally in cluster, if the secret matches)
void setUserWithoutCheckingPassword(const String & name, const Poco::Net::SocketAddress & address);
void setQuotaKey(String quota_key_);
/// Sets the current user assuming that he/she is already authenticated.
/// WARNING: This function doesn't check the password! Don't use it unless it's necessary!
void setUser(const UUID & user_id_);
UserPtr getUser() const;
String getUserName() const;
std::optional<UUID> getUserID() const;
void setQuotaKey(String quota_key_);
void setCurrentRoles(const std::vector<UUID> & current_roles_);
void setCurrentRolesDefault();
boost::container::flat_set<UUID> getCurrentRoles() const;
@ -591,12 +584,6 @@ public:
std::optional<UInt16> getTCPPortSecure() const;
/// Allow to use named sessions. The thread will be run to cleanup sessions after timeout has expired.
/// The method must be called at the server startup.
void enableNamedSessions();
std::shared_ptr<NamedSession> acquireNamedSession(const String & session_id, std::chrono::steady_clock::duration timeout, bool session_check);
/// For methods below you may need to acquire the context lock by yourself.
ContextMutablePtr getQueryContext() const;
@ -852,32 +839,6 @@ private:
StoragePolicySelectorPtr getStoragePolicySelector(std::lock_guard<std::mutex> & lock) const;
DiskSelectorPtr getDiskSelector(std::lock_guard<std::mutex> & /* lock */) const;
/// If the password is not set, the password will not be checked
void setUserImpl(const String & name, const std::optional<String> & password, const Poco::Net::SocketAddress & address);
};
class NamedSessions;
/// User name and session identifier. Named sessions are local to users.
using NamedSessionKey = std::pair<String, String>;
/// Named sessions. The user could specify session identifier to reuse settings and temporary tables in subsequent requests.
struct NamedSession
{
/// (user name, session id) pair under which the session is registered.
NamedSessionKey key;
/// Cleaning cycle at which the session is due to be closed; maintained by NamedSessions.
UInt64 close_cycle = 0;
/// The session's own context, created as a copy of the context passed to the constructor.
ContextMutablePtr context;
/// Timeout used by NamedSessions to choose the session's position in the close queue.
std::chrono::steady_clock::duration timeout;
/// The registry this session belongs to; release() forwards to it.
NamedSessions & parent;
NamedSession(NamedSessionKey key_, ContextPtr context_, std::chrono::steady_clock::duration timeout_, NamedSessions & parent_)
: key(key_), context(Context::createCopy(context_)), timeout(timeout_), parent(parent_)
{
}
/// Returns the session to `parent`, which schedules it for closing.
void release();
};
}

View File

@ -0,0 +1,414 @@
#include <Interpreters/Session.h>
#include <Access/AccessControlManager.h>
#include <Access/Credentials.h>
#include <Access/ContextAccess.h>
#include <Access/User.h>
#include <Common/Exception.h>
#include <Common/ThreadPool.h>
#include <Common/setThreadName.h>
#include <Interpreters/Context.h>
#include <atomic>
#include <condition_variable>
#include <deque>
#include <mutex>
#include <unordered_map>
#include <vector>
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
extern const int SESSION_NOT_FOUND;
extern const int SESSION_IS_LOCKED;
}
class NamedSessionsStorage;
/// User ID and session identifier. Named sessions are local to users.
using NamedSessionKey = std::pair<UUID, String>;
/// Named sessions. The user could specify session identifier to reuse settings and temporary tables in subsequent requests.
struct NamedSessionData
{
/// (user ID, session id) pair under which the session is registered.
NamedSessionKey key;
/// Cleaning cycle at which the session is due to be closed; maintained by NamedSessionsStorage.
UInt64 close_cycle = 0;
/// The session's own context, created as a copy of the context passed to the constructor.
ContextMutablePtr context;
/// Timeout used by NamedSessionsStorage to choose the session's position in the close queue.
std::chrono::steady_clock::duration timeout;
/// The storage this session belongs to; release() forwards to it.
NamedSessionsStorage & parent;
NamedSessionData(NamedSessionKey key_, ContextPtr context_, std::chrono::steady_clock::duration timeout_, NamedSessionsStorage & parent_)
: key(std::move(key_)), context(Context::createCopy(context_)), timeout(timeout_), parent(parent_)
{}
/// Returns the session to `parent`, which schedules it for closing.
void release();
};
/// Storage of named sessions keyed by (user ID, session id).
/// Owns a background thread (named "SessionCleaner") that periodically erases sessions
/// whose close cycle has been reached and that are not held by any client.
class NamedSessionsStorage
{
public:
    using Key = NamedSessionKey;

    /// Signals the cleaner thread to quit and joins it.
    /// Any exception raised while doing so is logged instead of propagated.
    ~NamedSessionsStorage()
    {
        try
        {
            {
                std::lock_guard lock{mutex};
                quit = true;
            }

            cond.notify_one();
            thread.join();
        }
        catch (...)
        {
            tryLogCurrentException(__PRETTY_FUNCTION__);
        }
    }

    /// Find existing session or create a new.
    ///
    /// Returns the session and a flag telling whether it has just been created.
    /// Throws SESSION_NOT_FOUND if the session is absent and `throw_if_not_found` is set,
    /// SESSION_IS_LOCKED if an existing session is currently held by somebody else.
    std::pair<std::shared_ptr<NamedSessionData>, bool> acquireSession(
        const ContextPtr & global_context,
        const UUID & user_id,
        const String & session_id,
        std::chrono::steady_clock::duration timeout,
        bool throw_if_not_found)
    {
        std::unique_lock lock(mutex);

        Key key{user_id, session_id};

        auto it = sessions.find(key);
        if (it == sessions.end())
        {
            if (throw_if_not_found)
                throw Exception("Session not found.", ErrorCodes::SESSION_NOT_FOUND);

            /// Create a new session from current context.
            /// NOTE(review): NamedSessionData's constructor calls Context::createCopy() on its argument again,
            /// so the context is copied twice here; one of the copies looks redundant. Confirm before removing.
            auto context = Context::createCopy(global_context);
            it = sessions.emplace(key, std::make_shared<NamedSessionData>(key, context, timeout, *this)).first;
            const auto & session = it->second;
            return {session, true};
        }
        else
        {
            /// Use existing session.
            const auto & session = it->second;

            /// The map itself holds one reference; any additional reference means a client is using the session.
            if (session.use_count() != 1)
                throw Exception("Session is locked by a concurrent client.", ErrorCodes::SESSION_IS_LOCKED);
            return {session, false};
        }
    }

    /// Puts the session on the close queue according to its timeout.
    void releaseSession(NamedSessionData & session)
    {
        std::unique_lock lock(mutex);
        scheduleCloseSession(session, lock);
    }

private:
    class SessionKeyHash
    {
    public:
        size_t operator()(const Key & key) const
        {
            SipHash hash;
            hash.update(key.first);
            hash.update(key.second);
            return hash.get64();
        }
    };

    /// TODO it's very complicated. Make simple std::map with time_t or boost::multi_index.
    using Container = std::unordered_map<Key, std::shared_ptr<NamedSessionData>, SessionKeyHash>;
    using CloseTimes = std::deque<std::vector<Key>>;

    Container sessions;

    /// Queue of buckets of session keys; the front bucket is examined on each cleaning cycle.
    CloseTimes close_times;
    std::chrono::steady_clock::duration close_interval = std::chrono::seconds(1);
    std::chrono::steady_clock::time_point close_cycle_time = std::chrono::steady_clock::now();
    UInt64 close_cycle = 0;

    /// Must be called with `mutex` locked (the unused lock parameter documents that).
    void scheduleCloseSession(NamedSessionData & session, std::unique_lock<std::mutex> &)
    {
        /// Push it on a queue of sessions to close, on a position corresponding to the timeout.
        /// (timeout is measured from current moment of time)

        const UInt64 close_index = session.timeout / close_interval + 1;
        const auto new_close_cycle = close_cycle + close_index;

        if (session.close_cycle != new_close_cycle)
        {
            session.close_cycle = new_close_cycle;
            if (close_times.size() < close_index + 1)
                close_times.resize(close_index + 1);
            close_times[close_index].emplace_back(session.key);
        }
    }

    /// Body of the cleaner thread: closes expired sessions, then sleeps until the next cycle or until `quit` is set.
    void cleanThread()
    {
        setThreadName("SessionCleaner");

        std::unique_lock lock{mutex};

        while (true)
        {
            auto interval = closeSessions(lock);

            if (cond.wait_for(lock, interval, [this]() -> bool { return quit; }))
                break;
        }
    }

    /// Closes sessions that have expired. Returns how long to wait for the next session to expire, if no new sessions are added.
    std::chrono::steady_clock::duration closeSessions(std::unique_lock<std::mutex> & lock)
    {
        const auto now = std::chrono::steady_clock::now();

        /// The time to close the next session did not come
        if (now < close_cycle_time)
            return close_cycle_time - now; /// Will sleep until it comes.

        const auto current_cycle = close_cycle;

        ++close_cycle;
        close_cycle_time = now + close_interval;

        if (close_times.empty())
            return close_interval;

        auto & sessions_to_close = close_times.front();

        for (const auto & key : sessions_to_close)
        {
            const auto session = sessions.find(key);

            if (session != sessions.end() && session->second->close_cycle <= current_cycle)
            {
                if (session->second.use_count() != 1)
                {
                    /// Skip but move it to close on the next cycle.
                    /// NOTE(review): `close_cycle` has already been incremented above while the front bucket is not popped yet,
                    /// so the key lands in the bucket examined on the next pass with a `close_cycle` one greater than that pass;
                    /// it looks like the next pass will ignore it and the session is re-queued only by release(). Confirm.
                    session->second->timeout = std::chrono::steady_clock::duration{0};
                    scheduleCloseSession(*session->second, lock);
                }
                else
                    sessions.erase(session);
            }
        }

        close_times.pop_front();
        return close_interval;
    }

    std::mutex mutex;
    std::condition_variable cond;
    std::atomic<bool> quit{false};

    /// Declared last: the thread starts running cleanThread() during construction and uses the members above.
    ThreadFromGlobalPool thread{&NamedSessionsStorage::cleanThread, this};
};
/// Hands this session back to the owning NamedSessionsStorage,
/// which schedules it for closing (see NamedSessionsStorage::releaseSession).
void NamedSessionData::release()
{
parent.releaseSession(*this);
}
/// Storage of named sessions shared by all Session objects (static member).
/// Stays empty until startupNamedSessions() is called.
std::optional<NamedSessionsStorage> Session::named_sessions = std::nullopt;
/// Constructs the named sessions storage in place.
/// Constructing the storage also starts its cleaner thread (it is a member of NamedSessionsStorage).
void Session::startupNamedSessions()
{
named_sessions.emplace();
}
/// Creates a session bound to the global context.
/// A fresh ClientInfo is prepared and tagged with the interface the client connected through;
/// it is kept in `prepared_client_info` until a session context is made.
Session::Session(const ContextPtr & global_context_, ClientInfo::Interface interface_)
    : global_context(global_context_)
{
    auto & client_info = prepared_client_info.emplace();
    client_info.interface = interface_;
}
/// Move constructor: compiler-generated member-wise move.
Session::Session(Session &&) = default;
/// If this session holds a named session, releases it explicitly so that it gets scheduled for closing.
Session::~Session()
{
    if (!named_session)
        return;

    /// Early release a NamedSessionData.
    named_session->release();
}
/// Reads the user `user_name` from the access control manager of the global context
/// and returns the type of authentication configured for that user.
Authentication::Type Session::getAuthenticationType(const String & user_name) const
{
    auto & access_control = global_context->getAccessControlManager();
    const auto user_entity = access_control.read<User>(user_name);
    return user_entity->authentication.getType();
}
/// Reads the user `user_name` from the access control manager of the global context
/// and returns the double SHA1 digest of that user's password.
Authentication::Digest Session::getPasswordDoubleSHA1(const String & user_name) const
{
    auto & access_control = global_context->getAccessControlManager();
    const auto user_entity = access_control.read<User>(user_name);
    return user_entity->authentication.getPasswordDoubleSHA1();
}
/// Convenience overload: wraps the name and password into BasicCredentials
/// and delegates to the credentials-based authenticate().
void Session::authenticate(const String & user_name, const String & password, const Poco::Net::SocketAddress & address)
{
    const BasicCredentials basic_credentials{user_name, password};
    authenticate(basic_credentials, address);
}
/// Authenticates the client with the given credentials coming from the given address.
///
/// Must be called before a session context is made; otherwise throws LOGICAL_ERROR.
/// The login itself is delegated to the access control manager of the global context
/// (presumably it throws on unknown user / wrong credentials, as documented for Context::authenticate).
/// On success remembers the user ID and records the user name and address in the prepared client info.
void Session::authenticate(const Credentials & credentials_, const Poco::Net::SocketAddress & address_)
{
    if (session_context)
        throw Exception("If there is a session context it must be created after authentication", ErrorCodes::LOGICAL_ERROR);

    user_id = global_context->getAccessControlManager().login(credentials_, address_.host());

    prepared_client_info->current_user = credentials_.getUserName();
    prepared_client_info->current_address = address_;

#if defined(ARCADIA_BUILD)
    /// This is harmful field that is used only in foreign "Arcadia" build.
    /// Clear it first so that a password from a previous authentication is never kept
    /// when the new credentials carry no password (same as Context::authenticate does).
    prepared_client_info->current_password.clear();
    if (const auto * basic_credentials = dynamic_cast<const BasicCredentials *>(&credentials_))
        prepared_client_info->current_password = basic_credentials->getPassword();
#endif
}
/// Returns the client info of the session context if it has been made,
/// otherwise the client info prepared for the future session context.
ClientInfo & Session::getClientInfo()
{
    if (session_context)
        return session_context->getClientInfo();

    return *prepared_client_info;
}
/// Returns the client info of the session context if it has been made,
/// otherwise the client info prepared for the future session context.
const ClientInfo & Session::getClientInfo() const
{
    if (session_context)
        return session_context->getClientInfo();

    return *prepared_client_info;
}
/// Makes the (unnamed) session context for this session and returns it.
///
/// Throws LOGICAL_ERROR if a session context already exists or if a query context has already been created.
/// The prepared client info is moved into the new context; after that `prepared_client_info` is empty.
ContextMutablePtr Session::makeSessionContext()
{
    if (session_context)
        throw Exception("Session context already exists", ErrorCodes::LOGICAL_ERROR);

    if (query_context_created)
        throw Exception("Session context must be created before any query context", ErrorCodes::LOGICAL_ERROR);

    /// Start from a copy of the global context and mark it as a session context.
    ContextMutablePtr fresh_context = Context::createCopy(global_context);
    fresh_context->makeSessionContext();

    /// Hand the prepared client info over to the new context.
    /// This is done before setUser() below, as in the original order of operations.
    auto & target_client_info = fresh_context->getClientInfo();
    target_client_info = std::move(prepared_client_info).value();
    prepared_client_info.reset();

    /// Set user information for the new context: current profiles, roles, access rights.
    if (user_id)
        fresh_context->setUser(*user_id);

    /// Publish the context only once it is fully set up.
    session_context = fresh_context;
    user = session_context->getUser();

    return session_context;
}
/// Creates or re-acquires a named session context identified by `session_id_` and the current user ID.
/// Throws LOGICAL_ERROR if a session context already exists, a query context has been made before,
/// or named sessions have not been enabled.
ContextMutablePtr Session::makeSessionContext(const String & session_id_, std::chrono::steady_clock::duration timeout_, bool session_check_)
{
    if (session_context)
        throw Exception("Session context already exists", ErrorCodes::LOGICAL_ERROR);
    if (query_context_created)
        throw Exception("Session context must be created before any query context", ErrorCodes::LOGICAL_ERROR);
    if (!named_sessions)
        throw Exception("Support for named sessions is not enabled", ErrorCodes::LOGICAL_ERROR);

    /// Either a brand-new named session is made, or the one previously registered
    /// for this `session_id` and `user_id` is returned.
    auto [acquired_session, freshly_created]
        = named_sessions->acquireSession(global_context, user_id.value_or(UUID{}), session_id_, timeout_, session_check_);

    auto named_context = acquired_session->context;
    named_context->makeSessionContext();

    /// The client info is replaced even for a reused session context,
    /// because it carries the actual connection information (client address, etc.)
    ClientInfo & target_client_info = named_context->getClientInfo();
    target_client_info = std::move(prepared_client_info).value();
    prepared_client_info.reset();

    /// Assign the user only if the context does not have one yet: current profiles, roles, access rights.
    if (user_id)
    {
        if (!named_context->getUser())
            named_context->setUser(*user_id);
    }

    /// Publish the finished context and remember which named session backs it.
    session_context = named_context;
    session_id = session_id_;
    named_session = acquired_session;
    named_session_created = freshly_created;
    user = session_context->getUser();
    return session_context;
}
/// Makes a query context whose client info is copied from `query_client_info`.
ContextMutablePtr Session::makeQueryContext(const ClientInfo & query_client_info) const
{
    const ClientInfo * const info_to_copy = &query_client_info;
    return makeQueryContextImpl(info_to_copy, nullptr);
}
/// Makes a query context whose client info is moved out of `query_client_info`.
ContextMutablePtr Session::makeQueryContext(ClientInfo && query_client_info) const
{
    ClientInfo * const info_to_move = &query_client_info;
    return makeQueryContextImpl(nullptr, info_to_move);
}
/// Common implementation behind the makeQueryContext() overloads.
/// At most one of `client_info_to_copy` / `client_info_to_move` is expected to be non-null;
/// when both are null the client info inherited from the parent context is kept.
ContextMutablePtr Session::makeQueryContextImpl(const ClientInfo * client_info_to_copy, ClientInfo * client_info_to_move) const
{
    /// A query context is derived from the session context when there is one, and from the global context otherwise.
    const ContextPtr parent_context = session_context ? ContextPtr{session_context} : global_context;

    ContextMutablePtr query_context = Context::createCopy(parent_context);
    query_context->makeQueryContext();

    /// Fill in the client info of the new query context.
    ClientInfo & query_info = query_context->getClientInfo();
    if (client_info_to_move)
    {
        query_info = std::move(*client_info_to_move);
    }
    else if (client_info_to_copy)
    {
        /// Nothing is copied when the caller passed the session's own client info.
        if (client_info_to_copy != &getClientInfo())
            query_info = *client_info_to_copy;
    }

    /// Copy current user's name and address if it was authenticated after query_client_info was initialized.
    const bool has_prepared_user = prepared_client_info && !prepared_client_info->current_user.empty();
    if (has_prepared_user)
    {
        query_info.current_user = prepared_client_info->current_user;
        query_info.current_address = prepared_client_info->current_address;
#if defined(ARCADIA_BUILD)
        query_info.current_password = prepared_client_info->current_password;
#endif
    }

    /// A query without an explicit kind is treated as an initial query.
    if (query_info.query_kind == ClientInfo::QueryKind::NO_QUERY)
        query_info.query_kind = ClientInfo::QueryKind::INITIAL_QUERY;

    /// For an initial query the initial user/address are the current ones.
    if (query_info.query_kind == ClientInfo::QueryKind::INITIAL_QUERY)
    {
        query_info.initial_user = query_info.current_user;
        query_info.initial_address = query_info.current_address;
    }

    /// Sets that row policies from the initial user should be used too.
    query_context->setInitialRowPolicy();

    /// Assign the user only if the context does not have one yet: current profiles, roles, access rights.
    if (user_id)
    {
        if (!query_context->getUser())
            query_context->setUser(*user_id);
    }

    /// Remember that a query context exists now; a session context can no longer be made afterwards.
    query_context_created = true;
    user = query_context->getUser();
    return query_context;
}
}

View File

@ -0,0 +1,90 @@
#pragma once
#include <Common/SettingsChanges.h>
#include <Access/Authentication.h>
#include <Interpreters/ClientInfo.h>
#include <Interpreters/Context_fwd.h>
#include <chrono>
#include <memory>
#include <optional>
namespace Poco::Net { class SocketAddress; }
namespace DB
{
class Credentials;
class Authentication;
struct NamedSessionData;
class NamedSessionsStorage;
struct User;
using UserPtr = std::shared_ptr<const User>;
/** Represents a user session from the server's perspective.
 * Basically it is just a smaller subset of the Context API which simplifies Context management.
 *
 * Holds the session context, facilitates acquisition of a named session and proper creation of query contexts.
 *
 * Expected order of calls: authenticate() first, then (optionally) one of makeSessionContext(),
 * then makeQueryContext() as many times as needed.
 */
class Session
{
public:
/// Allow to use named sessions. The thread will be run to cleanup sessions after timeout has expired.
/// The method must be called at the server startup.
static void startupNamedSessions();
Session(const ContextPtr & global_context_, ClientInfo::Interface interface_);
Session(Session &&);
~Session();
/// A session is movable but not copyable.
Session(const Session &) = delete;
Session& operator=(const Session &) = delete;
/// Provides information about the authentication type of a specified user.
Authentication::Type getAuthenticationType(const String & user_name) const;
/// Returns the double-SHA1 password digest stored for a specified user.
Authentication::Digest getPasswordDoubleSHA1(const String & user_name) const;
/// Sets the current user, checks the credentials and that the specified address is allowed to connect from.
/// The function throws an exception if there is no such user or password is wrong.
/// It must be called before a session context is made, otherwise a LOGICAL_ERROR exception is thrown.
void authenticate(const String & user_name, const String & password, const Poco::Net::SocketAddress & address);
void authenticate(const Credentials & credentials_, const Poco::Net::SocketAddress & address_);
/// Returns a reference to session ClientInfo:
/// the client info of the session context if it exists, otherwise the prepared client info.
ClientInfo & getClientInfo();
const ClientInfo & getClientInfo() const;
/// Makes a session context, can be used one or zero times.
/// The function also assigns a user to this context.
/// Throws LOGICAL_ERROR if a session context already exists or if a query context has already been made.
/// The overload taking `session_id_` makes or re-acquires a named session
/// and additionally requires named sessions to be enabled.
ContextMutablePtr makeSessionContext();
ContextMutablePtr makeSessionContext(const String & session_id_, std::chrono::steady_clock::duration timeout_, bool session_check_);
/// Returns the session context, or nullptr if it has not been made.
ContextMutablePtr sessionContext() { return session_context; }
ContextPtr sessionContext() const { return session_context; }
/// Makes a query context, can be used multiple times, with or without makeSessionContext() called earlier.
/// The query context will be created from a copy of a session context if it exists, or from a copy of
/// a global context otherwise. In the latter case the function also assigns a user to this context.
ContextMutablePtr makeQueryContext() const { return makeQueryContext(getClientInfo()); }
ContextMutablePtr makeQueryContext(const ClientInfo & query_client_info) const;
ContextMutablePtr makeQueryContext(ClientInfo && query_client_info) const;
private:
/// Exactly one of the two pointers is passed as non-null by the public overloads above.
ContextMutablePtr makeQueryContextImpl(const ClientInfo * client_info_to_copy, ClientInfo * client_info_to_move) const;
const ContextPtr global_context;
/// ClientInfo that will be copied to a session context when it's created.
/// It is reset once it has been moved into the session context.
std::optional<ClientInfo> prepared_client_info;
/// The user of the most recently made session or query context (mutable: updated by the const makeQueryContext()).
mutable UserPtr user;
/// ID of the authenticated user; set by authenticate().
std::optional<UUID> user_id;
ContextMutablePtr session_context;
/// Set when the first query context is made; after that a session context cannot be made anymore.
mutable bool query_context_created = false;
/// The fields below are set only when a named session context is made.
String session_id;
std::shared_ptr<NamedSessionData> named_session;
bool named_session_created = false;
/// Storage of named sessions; makeSessionContext(session_id_, ...) throws while it is empty.
static std::optional<NamedSessionsStorage> named_sessions;
};
}

View File

@ -64,7 +64,6 @@ void MySQLOutputFormat::initialize()
}
}
void MySQLOutputFormat::consume(Chunk chunk)
{
initialize();

View File

@ -13,6 +13,7 @@
#include <Interpreters/Context.h>
#include <Interpreters/InternalTextLogsQueue.h>
#include <Interpreters/executeQuery.h>
#include <Interpreters/Session.h>
#include <IO/ConcatReadBuffer.h>
#include <IO/ReadBufferFromString.h>
#include <IO/ReadHelpers.h>
@ -54,7 +55,6 @@ namespace ErrorCodes
extern const int NETWORK_ERROR;
extern const int NO_DATA_TO_INSERT;
extern const int SUPPORT_IS_DISABLED;
extern const int UNKNOWN_DATABASE;
}
namespace
@ -560,7 +560,7 @@ namespace
IServer & iserver;
Poco::Logger * log = nullptr;
std::shared_ptr<NamedSession> session;
std::optional<Session> session;
ContextMutablePtr query_context;
std::optional<CurrentThread::QueryScope> query_scope;
String query_text;
@ -689,34 +689,20 @@ namespace
password = "";
}
/// Create context.
query_context = Context::createCopy(iserver.context());
/// Authentication.
query_context->setUser(user, password, user_address);
query_context->setCurrentQueryId(query_info.query_id());
if (!quota_key.empty())
query_context->setQuotaKey(quota_key);
session.emplace(iserver.context(), ClientInfo::Interface::GRPC);
session->authenticate(user, password, user_address);
session->getClientInfo().quota_key = quota_key;
/// The user could specify session identifier and session timeout.
/// It allows to modify settings, create temporary tables and reuse them in subsequent requests.
if (!query_info.session_id().empty())
{
session = query_context->acquireNamedSession(
session->makeSessionContext(
query_info.session_id(), getSessionTimeout(query_info, iserver.config()), query_info.session_check());
query_context = Context::createCopy(session->context);
query_context->setSessionContext(session->context);
}
query_scope.emplace(query_context);
/// Set client info.
ClientInfo & client_info = query_context->getClientInfo();
client_info.query_kind = ClientInfo::QueryKind::INITIAL_QUERY;
client_info.interface = ClientInfo::Interface::GRPC;
client_info.initial_user = client_info.current_user;
client_info.initial_query_id = client_info.current_query_id;
client_info.initial_address = client_info.current_address;
query_context = session->makeQueryContext();
/// Prepare settings.
SettingsChanges settings_changes;
@ -726,11 +712,14 @@ namespace
}
query_context->checkSettingsConstraints(settings_changes);
query_context->applySettingsChanges(settings_changes);
const Settings & settings = query_context->getSettingsRef();
query_context->setCurrentQueryId(query_info.query_id());
query_scope.emplace(query_context);
/// Prepare for sending exceptions and logs.
send_exception_with_stacktrace = query_context->getSettingsRef().calculate_text_stack_trace;
const auto client_logs_level = query_context->getSettingsRef().send_logs_level;
const Settings & settings = query_context->getSettingsRef();
send_exception_with_stacktrace = settings.calculate_text_stack_trace;
const auto client_logs_level = settings.send_logs_level;
if (client_logs_level != LogsLevel::none)
{
logs_queue = std::make_shared<InternalTextLogsQueue>();
@ -741,14 +730,10 @@ namespace
/// Set the current database if specified.
if (!query_info.database().empty())
{
if (!DatabaseCatalog::instance().isDatabaseExist(query_info.database()))
throw Exception("Database " + query_info.database() + " doesn't exist", ErrorCodes::UNKNOWN_DATABASE);
query_context->setCurrentDatabase(query_info.database());
}
/// The interactive delay will be used to show progress.
interactive_delay = query_context->getSettingsRef().interactive_delay;
interactive_delay = settings.interactive_delay;
query_context->setProgressCallback([this](const Progress & value) { return progress.incrementPiecewiseAtomically(value); });
/// Parse the query.
@ -1254,8 +1239,6 @@ namespace
io = {};
query_scope.reset();
query_context.reset();
if (session)
session->release();
session.reset();
}

View File

@ -21,6 +21,7 @@
#include <Interpreters/Context.h>
#include <Interpreters/QueryParameterVisitor.h>
#include <Interpreters/executeQuery.h>
#include <Interpreters/Session.h>
#include <Server/HTTPHandlerFactory.h>
#include <Server/HTTPHandlerRequestFilter.h>
#include <Server/IServer.h>
@ -261,6 +262,7 @@ void HTTPHandler::pushDelayedResults(Output & used_output)
/// Creates a handler that serves requests on behalf of `server_` and logs under the logger called `name`.
/// Keeps a reference to the settings of the server's context (`default_settings`) and reads the
/// server display name from the `display_name` config key, falling back to getFQDNOrHostName().
HTTPHandler::HTTPHandler(IServer & server_, const std::string & name)
: server(server_)
, log(&Poco::Logger::get(name))
, default_settings(server.context()->getSettingsRef())
{
server_display_name = server.config().getString("display_name", getFQDNOrHostName());
}
@ -268,14 +270,10 @@ HTTPHandler::HTTPHandler(IServer & server_, const std::string & name)
/// We need d-tor to be present in this translation unit to make it play well with some
/// forward decls in the header. Other than that, the default d-tor would be OK.
HTTPHandler::~HTTPHandler()
{
(void)this;
}
HTTPHandler::~HTTPHandler() = default;
bool HTTPHandler::authenticateUser(
ContextMutablePtr context,
HTTPServerRequest & request,
HTMLForm & params,
HTTPServerResponse & response)
@ -352,7 +350,7 @@ bool HTTPHandler::authenticateUser(
else
{
if (!request_credentials)
request_credentials = request_context->makeGSSAcceptorContext();
request_credentials = server.context()->makeGSSAcceptorContext();
auto * gss_acceptor_context = dynamic_cast<GSSAcceptorContext *>(request_credentials.get());
if (!gss_acceptor_context)
@ -378,10 +376,7 @@ bool HTTPHandler::authenticateUser(
}
/// Set client info. It will be used for quota accounting parameters in 'setUser' method.
ClientInfo & client_info = context->getClientInfo();
client_info.query_kind = ClientInfo::QueryKind::INITIAL_QUERY;
client_info.interface = ClientInfo::Interface::HTTP;
ClientInfo & client_info = session->getClientInfo();
ClientInfo::HTTPMethod http_method = ClientInfo::HTTPMethod::UNKNOWN;
if (request.getMethod() == HTTPServerRequest::HTTP_GET)
@ -393,10 +388,11 @@ bool HTTPHandler::authenticateUser(
client_info.http_user_agent = request.get("User-Agent", "");
client_info.http_referer = request.get("Referer", "");
client_info.forwarded_for = request.get("X-Forwarded-For", "");
client_info.quota_key = quota_key;
try
{
context->setUser(*request_credentials, request.clientAddress());
session->authenticate(*request_credentials, request.clientAddress());
}
catch (const Authentication::Require<BasicCredentials> & required_credentials)
{
@ -413,7 +409,7 @@ bool HTTPHandler::authenticateUser(
}
catch (const Authentication::Require<GSSAcceptorContext> & required_credentials)
{
request_credentials = request_context->makeGSSAcceptorContext();
request_credentials = server.context()->makeGSSAcceptorContext();
if (required_credentials.getRealm().empty())
response.set("WWW-Authenticate", "Negotiate");
@ -426,20 +422,11 @@ bool HTTPHandler::authenticateUser(
}
request_credentials.reset();
if (!quota_key.empty())
context->setQuotaKey(quota_key);
/// Query sent through HTTP interface is initial.
client_info.initial_user = client_info.current_user;
client_info.initial_address = client_info.current_address;
return true;
}
void HTTPHandler::processQuery(
ContextMutablePtr context,
HTTPServerRequest & request,
HTMLForm & params,
HTTPServerResponse & response,
@ -450,13 +437,11 @@ void HTTPHandler::processQuery(
LOG_TRACE(log, "Request URI: {}", request.getURI());
if (!authenticateUser(context, request, params, response))
if (!authenticateUser(request, params, response))
return; // '401 Unauthorized' response with 'Negotiate' has been sent at this point.
/// The user could specify session identifier and session timeout.
/// It allows to modify settings, create temporary tables and reuse them in subsequent requests.
std::shared_ptr<NamedSession> session;
String session_id;
std::chrono::steady_clock::duration session_timeout;
bool session_is_set = params.has("session_id");
@ -467,43 +452,30 @@ void HTTPHandler::processQuery(
session_id = params.get("session_id");
session_timeout = parseSessionTimeout(config, params);
std::string session_check = params.get("session_check", "");
session = context->acquireNamedSession(session_id, session_timeout, session_check == "1");
context->copyFrom(session->context); /// FIXME: maybe move this part to HandleRequest(), copyFrom() is used only here.
context->setSessionContext(session->context);
session->makeSessionContext(session_id, session_timeout, session_check == "1");
}
SCOPE_EXIT({
if (session)
session->release();
});
// Parse the OpenTelemetry traceparent header.
// Disable in Arcadia -- it interferes with the
// test_clickhouse.TestTracing.test_tracing_via_http_proxy[traceparent] test.
ClientInfo client_info = session->getClientInfo();
#if !defined(ARCADIA_BUILD)
if (request.has("traceparent"))
{
std::string opentelemetry_traceparent = request.get("traceparent");
std::string error;
if (!context->getClientInfo().client_trace_context.parseTraceparentHeader(
if (!client_info.client_trace_context.parseTraceparentHeader(
opentelemetry_traceparent, error))
{
throw Exception(ErrorCodes::BAD_REQUEST_PARAMETER,
"Failed to parse OpenTelemetry traceparent header '{}': {}",
opentelemetry_traceparent, error);
}
context->getClientInfo().client_trace_context.tracestate = request.get("tracestate", "");
client_info.client_trace_context.tracestate = request.get("tracestate", "");
}
#endif
// Set the query id supplied by the user, if any, and also update the OpenTelemetry fields.
context->setCurrentQueryId(params.get("query_id", request.get("X-ClickHouse-Query-Id", "")));
ClientInfo & client_info = context->getClientInfo();
client_info.initial_query_id = client_info.current_query_id;
auto context = session->makeQueryContext(std::move(client_info));
/// The client can pass a HTTP header indicating supported compression method (gzip or deflate).
String http_response_compression_methods = request.get("Accept-Encoding", "");
@ -568,7 +540,7 @@ void HTTPHandler::processQuery(
if (buffer_until_eof)
{
const std::string tmp_path(context->getTemporaryVolume()->getDisk()->getPath());
const std::string tmp_path(server.context()->getTemporaryVolume()->getDisk()->getPath());
const std::string tmp_path_template(tmp_path + "http_buffers/");
auto create_tmp_disk_buffer = [tmp_path_template] (const WriteBufferPtr &)
@ -714,6 +686,9 @@ void HTTPHandler::processQuery(
context->checkSettingsConstraints(settings_changes);
context->applySettingsChanges(settings_changes);
// Set the query id supplied by the user, if any, and also update the OpenTelemetry fields.
context->setCurrentQueryId(params.get("query_id", request.get("X-ClickHouse-Query-Id", "")));
const auto & query = getQuery(request, params, context);
std::unique_ptr<ReadBuffer> in_param = std::make_unique<ReadBufferFromString>(query);
in = has_external_data ? std::move(in_param) : std::make_unique<ConcatReadBuffer>(*in_param, *in_post_maybe_compressed);
@ -864,23 +839,10 @@ void HTTPHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse
setThreadName("HTTPHandler");
ThreadStatus thread_status;
SCOPE_EXIT({
// If there is no request_credentials instance waiting for the next round, then the request is processed,
// so no need to preserve request_context either.
// Needs to be performed with respect to the other destructors in the scope though.
if (!request_credentials)
request_context.reset();
});
if (!request_context)
{
// Context should be initialized before anything, for correct memory accounting.
request_context = Context::createCopy(server.context());
request_credentials.reset();
}
/// Cannot be set here, since query_id is unknown.
session = std::make_unique<Session>(server.context(), ClientInfo::Interface::HTTP);
SCOPE_EXIT({ session.reset(); });
std::optional<CurrentThread::QueryScope> query_scope;
Output used_output;
/// In case of exception, send stack trace to client.
@ -894,7 +856,7 @@ void HTTPHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse
if (request.getVersion() == HTTPServerRequest::HTTP_1_1)
response.setChunkedTransferEncoding(true);
HTMLForm params(request_context->getSettingsRef(), request);
HTMLForm params(default_settings, request);
with_stacktrace = params.getParsed<bool>("stacktrace", false);
/// FIXME: maybe this check is already unnecessary.
@ -906,7 +868,7 @@ void HTTPHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse
ErrorCodes::HTTP_LENGTH_REQUIRED);
}
processQuery(request_context, request, params, response, used_output, query_scope);
processQuery(request, params, response, used_output, query_scope);
LOG_DEBUG(log, (request_credentials ? "Authentication in progress..." : "Done processing query"));
}
catch (...)

View File

@ -18,8 +18,10 @@ namespace Poco { class Logger; }
namespace DB
{
class Session;
class Credentials;
class IServer;
struct Settings;
class WriteBufferFromHTTPServerResponse;
using CompiledRegexPtr = std::shared_ptr<const re2::RE2>;
@ -71,25 +73,30 @@ private:
CurrentMetrics::Increment metric_increment{CurrentMetrics::HTTPConnection};
// The request_context and the request_credentials instances may outlive a single request/response loop.
/// Reference to the immutable settings in the global context.
/// Those settings are used only to extract a http request's parameters.
/// See settings http_max_fields, http_max_field_name_size, http_max_field_value_size in HTMLForm.
const Settings & default_settings;
// session is reset at the end of each request/response.
std::unique_ptr<Session> session;
// The request_credentials instance may outlive a single request/response loop.
// This happens only when the authentication mechanism requires more than a single request/response exchange (e.g., SPNEGO).
ContextMutablePtr request_context;
std::unique_ptr<Credentials> request_credentials;
// Returns true when the user successfully authenticated,
// the request_context instance will be configured accordingly, and the request_credentials instance will be dropped.
// the session instance will be configured accordingly, and the request_credentials instance will be dropped.
// Returns false when the user is not authenticated yet, and the 'Negotiate' response is sent,
// the request_context and request_credentials instances are preserved.
// the session and request_credentials instances are preserved.
// Throws an exception if authentication failed.
bool authenticateUser(
ContextMutablePtr context,
HTTPServerRequest & request,
HTMLForm & params,
HTTPServerResponse & response);
/// Also initializes 'used_output'.
void processQuery(
ContextMutablePtr context,
HTTPServerRequest & request,
HTMLForm & params,
HTTPServerResponse & response,

View File

@ -3,11 +3,12 @@
#include <limits>
#include <Common/NetException.h>
#include <Common/OpenSSLHelpers.h>
#include <Core/MySQL/Authentication.h>
#include <Core/MySQL/PacketsGeneric.h>
#include <Core/MySQL/PacketsConnection.h>
#include <Core/MySQL/PacketsProtocolText.h>
#include <Core/NamesAndTypes.h>
#include <DataStreams/copyData.h>
#include <Interpreters/Session.h>
#include <Interpreters/executeQuery.h>
#include <IO/copyData.h>
#include <IO/LimitReadBuffer.h>
@ -18,9 +19,8 @@
#include <IO/ReadHelpers.h>
#include <Storages/IStorage.h>
#include <regex>
#include <Access/User.h>
#include <Access/AccessControlManager.h>
#include <Common/setThreadName.h>
#include <Core/MySQL/Authentication.h>
#if !defined(ARCADIA_BUILD)
# include <Common/config_version.h>
@ -70,7 +70,6 @@ MySQLHandler::MySQLHandler(IServer & server_, const Poco::Net::StreamSocket & so
, server(server_)
, log(&Poco::Logger::get("MySQLHandler"))
, connection_id(connection_id_)
, connection_context(Context::createCopy(server.context()))
, auth_plugin(new MySQLProtocol::Authentication::Native41())
{
server_capabilities = CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH | CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA | CLIENT_CONNECT_WITH_DB | CLIENT_DEPRECATE_EOF;
@ -87,11 +86,11 @@ void MySQLHandler::run()
{
setThreadName("MySQLHandler");
ThreadStatus thread_status;
connection_context->makeSessionContext();
connection_context->getClientInfo().interface = ClientInfo::Interface::MYSQL;
connection_context->setDefaultFormat("MySQLWire");
connection_context->getClientInfo().connection_id = connection_id;
connection_context->getClientInfo().query_kind = ClientInfo::QueryKind::INITIAL_QUERY;
session = std::make_unique<Session>(server.context(), ClientInfo::Interface::MYSQL);
SCOPE_EXIT({ session.reset(); });
session->getClientInfo().connection_id = connection_id;
in = std::make_shared<ReadBufferFromPocoSocket>(socket());
out = std::make_shared<WriteBufferFromPocoSocket>(socket());
@ -125,14 +124,12 @@ void MySQLHandler::run()
authenticate(handshake_response.username, handshake_response.auth_plugin_name, handshake_response.auth_response);
connection_context->getClientInfo().initial_user = handshake_response.username;
try
{
session->makeSessionContext();
session->sessionContext()->setDefaultFormat("MySQLWire");
if (!handshake_response.database.empty())
connection_context->setCurrentDatabase(handshake_response.database);
connection_context->setCurrentQueryId(Poco::format("mysql:%lu", connection_id));
session->sessionContext()->setCurrentDatabase(handshake_response.database);
}
catch (const Exception & exc)
{
@ -249,15 +246,13 @@ void MySQLHandler::authenticate(const String & user_name, const String & auth_pl
try
{
// For compatibility with JavaScript MySQL client, Native41 authentication plugin is used when possible (if password is specified using double SHA1). Otherwise SHA256 plugin is used.
auto user = connection_context->getAccessControlManager().read<User>(user_name);
const DB::Authentication::Type user_auth_type = user->authentication.getType();
if (user_auth_type == DB::Authentication::SHA256_PASSWORD)
if (session->getAuthenticationType(user_name) == DB::Authentication::SHA256_PASSWORD)
{
authPluginSSL();
}
std::optional<String> auth_response = auth_plugin_name == auth_plugin->getName() ? std::make_optional<String>(initial_auth_response) : std::nullopt;
auth_plugin->authenticate(user_name, auth_response, connection_context, packet_endpoint, secure_connection, socket().peerAddress());
auth_plugin->authenticate(user_name, *session, auth_response, packet_endpoint, secure_connection, socket().peerAddress());
}
catch (const Exception & exc)
{
@ -273,7 +268,7 @@ void MySQLHandler::comInitDB(ReadBuffer & payload)
String database;
readStringUntilEOF(database, payload);
LOG_DEBUG(log, "Setting current database to {}", database);
connection_context->setCurrentDatabase(database);
session->sessionContext()->setCurrentDatabase(database);
packet_endpoint->sendPacket(OKPacket(0, client_capabilities, 0, 0, 1), true);
}
@ -281,8 +276,9 @@ void MySQLHandler::comFieldList(ReadBuffer & payload)
{
ComFieldList packet;
packet.readPayloadWithUnpacked(payload);
String database = connection_context->getCurrentDatabase();
StoragePtr table_ptr = DatabaseCatalog::instance().getTable({database, packet.table}, connection_context);
const auto session_context = session->sessionContext();
String database = session_context->getCurrentDatabase();
StoragePtr table_ptr = DatabaseCatalog::instance().getTable({database, packet.table}, session_context);
auto metadata_snapshot = table_ptr->getInMemoryMetadataPtr();
for (const NameAndTypePair & column : metadata_snapshot->getColumns().getAll())
{
@ -329,7 +325,9 @@ void MySQLHandler::comQuery(ReadBuffer & payload)
ReadBufferFromString replacement(replacement_query);
auto query_context = Context::createCopy(connection_context);
auto query_context = session->makeQueryContext();
query_context->setCurrentQueryId(Poco::format("mysql:%lu", connection_id));
CurrentThread::QueryScope query_scope{query_context};
std::atomic<size_t> affected_rows {0};
auto prev = query_context->getProgressCallback();
@ -341,8 +339,6 @@ void MySQLHandler::comQuery(ReadBuffer & payload)
affected_rows += progress.written_rows;
});
CurrentThread::QueryScope query_scope{query_context};
FormatSettings format_settings;
format_settings.mysql_wire.client_capabilities = client_capabilities;
format_settings.mysql_wire.max_packet_size = max_packet_size;

View File

@ -17,6 +17,8 @@
# include <Poco/Net/SecureStreamSocket.h>
#endif
#include <memory>
namespace CurrentMetrics
{
extern const Metric MySQLConnection;
@ -61,7 +63,7 @@ protected:
uint8_t sequence_id = 0;
MySQLProtocol::PacketEndpointPtr packet_endpoint;
ContextMutablePtr connection_context;
std::unique_ptr<Session> session;
using ReplacementFn = std::function<String(const String & query)>;
using Replacements = std::unordered_map<std::string, ReplacementFn>;

View File

@ -2,6 +2,7 @@
#include <IO/ReadHelpers.h>
#include <IO/ReadBufferFromString.h>
#include <IO/WriteBufferFromPocoSocket.h>
#include <Interpreters/Context.h>
#include <Interpreters/executeQuery.h>
#include "PostgreSQLHandler.h"
#include <Parsers/parseQuery.h>
@ -33,7 +34,6 @@ PostgreSQLHandler::PostgreSQLHandler(
std::vector<std::shared_ptr<PostgreSQLProtocol::PGAuthentication::AuthenticationMethod>> & auth_methods_)
: Poco::Net::TCPServerConnection(socket_)
, server(server_)
, connection_context(Context::createCopy(server.context()))
, ssl_enabled(ssl_enabled_)
, connection_id(connection_id_)
, authentication_manager(auth_methods_)
@ -52,10 +52,9 @@ void PostgreSQLHandler::run()
{
setThreadName("PostgresHandler");
ThreadStatus thread_status;
connection_context->makeSessionContext();
connection_context->getClientInfo().interface = ClientInfo::Interface::POSTGRESQL;
connection_context->setDefaultFormat("PostgreSQLWire");
connection_context->getClientInfo().query_kind = ClientInfo::QueryKind::INITIAL_QUERY;
session = std::make_unique<Session>(server.context(), ClientInfo::Interface::POSTGRESQL);
SCOPE_EXIT({ session.reset(); });
try
{
@ -123,18 +122,15 @@ bool PostgreSQLHandler::startup()
}
std::unique_ptr<PostgreSQLProtocol::Messaging::StartupMessage> start_up_msg = receiveStartupMessage(payload_size);
authentication_manager.authenticate(start_up_msg->user, connection_context, *message_transport, socket().peerAddress());
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<Int32> dis(0, INT32_MAX);
secret_key = dis(gen);
const auto & user_name = start_up_msg->user;
authentication_manager.authenticate(user_name, *session, *message_transport, socket().peerAddress());
try
{
session->makeSessionContext();
session->sessionContext()->setDefaultFormat("PostgreSQLWire");
if (!start_up_msg->database.empty())
connection_context->setCurrentDatabase(start_up_msg->database);
connection_context->setCurrentQueryId(Poco::format("postgres:%d:%d", connection_id, secret_key));
session->sessionContext()->setCurrentDatabase(start_up_msg->database);
}
catch (const Exception & exc)
{
@ -214,16 +210,15 @@ void PostgreSQLHandler::sendParameterStatusData(PostgreSQLProtocol::Messaging::S
void PostgreSQLHandler::cancelRequest()
{
connection_context->setCurrentQueryId("");
connection_context->setDefaultFormat("Null");
std::unique_ptr<PostgreSQLProtocol::Messaging::CancelRequest> msg =
message_transport->receiveWithPayloadSize<PostgreSQLProtocol::Messaging::CancelRequest>(8);
String query = Poco::format("KILL QUERY WHERE query_id = 'postgres:%d:%d'", msg->process_id, msg->secret_key);
ReadBufferFromString replacement(query);
executeQuery(replacement, *out, true, connection_context, {});
auto query_context = session->makeQueryContext();
query_context->setCurrentQueryId("");
executeQuery(replacement, *out, true, query_context, {});
}
inline std::unique_ptr<PostgreSQLProtocol::Messaging::StartupMessage> PostgreSQLHandler::receiveStartupMessage(int payload_size)
@ -269,18 +264,25 @@ void PostgreSQLHandler::processQuery()
return;
}
const auto & settings = connection_context->getSettingsRef();
const auto & settings = session->sessionContext()->getSettingsRef();
std::vector<String> queries;
auto parse_res = splitMultipartQuery(query->query, queries, settings.max_query_size, settings.max_parser_depth);
if (!parse_res.second)
throw Exception("Cannot parse and execute the following part of query: " + String(parse_res.first), ErrorCodes::SYNTAX_ERROR);
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<Int32> dis(0, INT32_MAX);
for (const auto & spl_query : queries)
{
/// FIXME why do we execute all queries in a single connection context?
CurrentThread::QueryScope query_scope{connection_context};
secret_key = dis(gen);
auto query_context = session->makeQueryContext();
query_context->setCurrentQueryId(Poco::format("postgres:%d:%d", connection_id, secret_key));
CurrentThread::QueryScope query_scope{query_context};
ReadBufferFromString read_buf(spl_query);
executeQuery(read_buf, *out, false, connection_context, {});
executeQuery(read_buf, *out, false, query_context, {});
PostgreSQLProtocol::Messaging::CommandComplete::Command command =
PostgreSQLProtocol::Messaging::CommandComplete::classifyQuery(spl_query);

View File

@ -18,6 +18,8 @@ namespace CurrentMetrics
namespace DB
{
class Session;
/** PostgreSQL wire protocol implementation.
* For more info see https://www.postgresql.org/docs/current/protocol.html
*/
@ -37,7 +39,7 @@ private:
Poco::Logger * log = &Poco::Logger::get("PostgreSQLHandler");
IServer & server;
ContextMutablePtr connection_context;
std::unique_ptr<Session> session;
bool ssl_enabled = false;
Int32 connection_id = 0;
Int32 secret_key = 0;

View File

@ -24,10 +24,12 @@
#include <Interpreters/TablesStatus.h>
#include <Interpreters/InternalTextLogsQueue.h>
#include <Interpreters/OpenTelemetrySpanLog.h>
#include <Interpreters/Session.h>
#include <Storages/StorageReplicatedMergeTree.h>
#include <Storages/MergeTree/MergeTreeDataPartUUID.h>
#include <Storages/StorageS3Cluster.h>
#include <Core/ExternalTable.h>
#include <Access/Credentials.h>
#include <Storages/ColumnDefault.h>
#include <DataTypes/DataTypeLowCardinality.h>
#include <Compression/CompressionFactory.h>
@ -73,7 +75,6 @@ namespace ErrorCodes
extern const int LOGICAL_ERROR;
extern const int ATTEMPT_TO_READ_AFTER_EOF;
extern const int CLIENT_HAS_CONNECTED_TO_WRONG_PORT;
extern const int UNKNOWN_DATABASE;
extern const int UNKNOWN_EXCEPTION;
extern const int UNKNOWN_PACKET_FROM_CLIENT;
extern const int POCO_EXCEPTION;
@ -88,11 +89,10 @@ TCPHandler::TCPHandler(IServer & server_, const Poco::Net::StreamSocket & socket
, server(server_)
, parse_proxy_protocol(parse_proxy_protocol_)
, log(&Poco::Logger::get("TCPHandler"))
, connection_context(Context::createCopy(server.context()))
, query_context(Context::createCopy(server.context()))
, server_display_name(std::move(server_display_name_))
{
}
TCPHandler::~TCPHandler()
{
try
@ -112,16 +112,11 @@ void TCPHandler::runImpl()
setThreadName("TCPHandler");
ThreadStatus thread_status;
connection_context = Context::createCopy(server.context());
connection_context->makeSessionContext();
session = std::make_unique<Session>(server.context(), ClientInfo::Interface::TCP);
extractConnectionSettingsFromContext(server.context());
/// These timeouts can be changed after receiving query.
auto global_receive_timeout = connection_context->getSettingsRef().receive_timeout;
auto global_send_timeout = connection_context->getSettingsRef().send_timeout;
socket().setReceiveTimeout(global_receive_timeout);
socket().setSendTimeout(global_send_timeout);
socket().setReceiveTimeout(receive_timeout);
socket().setSendTimeout(send_timeout);
socket().setNoDelay(true);
in = std::make_shared<ReadBufferFromPocoSocket>(socket());
@ -159,34 +154,27 @@ void TCPHandler::runImpl()
try
{
/// We try to send error information to the client.
sendException(e, connection_context->getSettingsRef().calculate_text_stack_trace);
sendException(e, send_exception_with_stack_trace);
}
catch (...) {}
throw;
}
/// When connecting, the default database can be specified.
if (!default_database.empty())
{
if (!DatabaseCatalog::instance().isDatabaseExist(default_database))
{
Exception e("Database " + backQuote(default_database) + " doesn't exist", ErrorCodes::UNKNOWN_DATABASE);
LOG_ERROR(log, getExceptionMessage(e, true));
sendException(e, connection_context->getSettingsRef().calculate_text_stack_trace);
return;
}
connection_context->setCurrentDatabase(default_database);
}
Settings connection_settings = connection_context->getSettings();
UInt64 idle_connection_timeout = connection_settings.idle_connection_timeout;
UInt64 poll_interval = connection_settings.poll_interval;
sendHello();
connection_context->setProgressCallback([this] (const Progress & value) { return this->updateProgress(value); });
if (!is_interserver_mode) /// In interserver mode queries are executed without a session context.
{
session->makeSessionContext();
/// If the session has been created, then the settings in the session context have been updated.
/// So it's better to update the connection settings for flexibility.
extractConnectionSettingsFromContext(session->sessionContext());
/// When connecting, the default database could be specified.
if (!default_database.empty())
session->sessionContext()->setCurrentDatabase(default_database);
}
while (true)
{
@ -208,9 +196,6 @@ void TCPHandler::runImpl()
if (server.isCancelled() || in->eof())
break;
/// Set context of request.
query_context = Context::createCopy(connection_context);
Stopwatch watch;
state.reset();
@ -223,8 +208,6 @@ void TCPHandler::runImpl()
std::optional<DB::Exception> exception;
bool network_error = false;
bool send_exception_with_stack_trace = true;
try
{
/// If a user passed query-local timeouts, reset socket to initial state at the end of the query
@ -237,23 +220,22 @@ void TCPHandler::runImpl()
if (!receivePacket())
continue;
/** If Query received, then settings in query_context has been updated
* So, update some other connection settings, for flexibility.
*/
{
const Settings & settings = query_context->getSettingsRef();
idle_connection_timeout = settings.idle_connection_timeout;
poll_interval = settings.poll_interval;
}
/** If part_uuids got received in previous packet, trying to read again.
*/
if (state.empty() && state.part_uuids && !receivePacket())
if (state.empty() && state.part_uuids_to_ignore && !receivePacket())
continue;
query_scope.emplace(query_context);
send_exception_with_stack_trace = query_context->getSettingsRef().calculate_text_stack_trace;
/// If a query has been received, then the settings in query_context have been updated.
/// So it's better to update the connection settings for flexibility.
extractConnectionSettingsFromContext(query_context);
/// Sync timeouts on client and server during current query to avoid dangling queries on server
/// NOTE: We use send_timeout for the receive timeout and vice versa (change arguments ordering in TimeoutSetter),
/// because send_timeout is client-side setting which has opposite meaning on the server side.
/// NOTE: these settings are applied only for current connection (not for distributed tables' connections)
state.timeout_setter = std::make_unique<TimeoutSetter>(socket(), receive_timeout, send_timeout);
/// Should we send internal logs to client?
const auto client_logs_level = query_context->getSettingsRef().send_logs_level;
@ -266,20 +248,18 @@ void TCPHandler::runImpl()
CurrentThread::setFatalErrorCallback([this]{ sendLogs(); });
}
query_context->setExternalTablesInitializer([&connection_settings, this] (ContextPtr context)
query_context->setExternalTablesInitializer([this] (ContextPtr context)
{
if (context != query_context)
throw Exception("Unexpected context in external tables initializer", ErrorCodes::LOGICAL_ERROR);
/// Get blocks of temporary tables
readData(connection_settings);
readData();
/// Reset the input stream, as we received an empty block while receiving external table data.
/// So, the stream has been marked as cancelled and we can't read from it anymore.
state.block_in.reset();
state.maybe_compressed_in.reset(); /// For more accurate accounting by MemoryTracker.
state.temporary_tables_read = true;
});
/// Send structure of columns to client for function input()
@ -303,15 +283,12 @@ void TCPHandler::runImpl()
sendData(state.input_header);
});
query_context->setInputBlocksReaderCallback([&connection_settings, this] (ContextPtr context) -> Block
query_context->setInputBlocksReaderCallback([this] (ContextPtr context) -> Block
{
if (context != query_context)
throw Exception("Unexpected context in InputBlocksReader", ErrorCodes::LOGICAL_ERROR);
size_t poll_interval_ms;
int receive_timeout;
std::tie(poll_interval_ms, receive_timeout) = getReadTimeouts(connection_settings);
if (!readDataNext(poll_interval_ms, receive_timeout))
if (!readDataNext())
{
state.block_in.reset();
state.maybe_compressed_in.reset();
@ -334,15 +311,13 @@ void TCPHandler::runImpl()
/// Processing Query
state.io = executeQuery(state.query, query_context, false, state.stage, may_have_embedded_data);
unknown_packet_in_send_data = query_context->getSettingsRef().unknown_packet_in_send_data;
after_check_cancelled.restart();
after_send_progress.restart();
if (state.io.out)
{
state.need_receive_data_for_insert = true;
processInsertQuery(connection_settings);
processInsertQuery();
}
else if (state.need_receive_data_for_input) // It implies pipeline execution
{
@ -458,16 +433,17 @@ void TCPHandler::runImpl()
try
{
if (exception && !state.temporary_tables_read)
query_context->initializeExternalTablesIfSet();
/// A query packet is always followed by one or more data packets.
/// If some of those data packets are left, try to skip them.
if (exception && !state.empty() && !state.read_all_data)
skipData();
}
catch (...)
{
network_error = true;
LOG_WARNING(log, "Can't read external tables after query failure.");
LOG_WARNING(log, "Can't skip data packets after query failure.");
}
try
{
/// QueryState should be cleared before QueryScope, since otherwise
@ -498,75 +474,94 @@ void TCPHandler::runImpl()
}
bool TCPHandler::readDataNext(size_t poll_interval, time_t receive_timeout)
/// Copies the connection-related settings of `context` into the handler's own members,
/// so that later code can use them without going through a context.
void TCPHandler::extractConnectionSettingsFromContext(const ContextPtr & context)
{
    const auto & current = context->getSettingsRef();

    /// Socket timeouts and polling.
    receive_timeout = current.receive_timeout;
    send_timeout = current.send_timeout;
    idle_connection_timeout = current.idle_connection_timeout;
    poll_interval = current.poll_interval;
    interactive_delay = current.interactive_delay;

    /// Error reporting and handling of unknown packets.
    send_exception_with_stack_trace = current.calculate_text_stack_trace;
    unknown_packet_in_send_data = current.unknown_packet_in_send_data;

    /// Artificial delays (the call sites describe them as "for testing").
    sleep_in_receive_cancel = current.sleep_in_receive_cancel_ms;
    sleep_in_send_tables_status = current.sleep_in_send_tables_status_ms;
}
/// Waits for the next packet from the client and processes it via receivePacket().
/// Returns true when a packet was received and receivePacket() reported success; otherwise false.
/// Throws SOCKET_TIMEOUT when the total wait exceeds receive_timeout.
/// NOTE(review): this span appears to come from a flattened diff — several statements occur twice
/// (e.g. the two poll() calls, the two timeout checks, the trailing eof()/receivePacket() tail),
/// which look like the before/after variants of the same step. Confirm against the real file.
bool TCPHandler::readDataNext()
{
Stopwatch watch(CLOCK_MONOTONIC_COARSE);
/// Poll interval should not be greater than receive_timeout
constexpr UInt64 min_timeout_ms = 5000; // 5000 microseconds = 5 ms: despite the `_ms` suffix the value is in microseconds
/// NOTE: `timeout_ms` is likewise in microseconds: poll_interval is multiplied by 1000000
/// and compared against receive_timeout.totalMicroseconds().
UInt64 timeout_ms = std::max(min_timeout_ms, std::min(poll_interval * 1000000, static_cast<UInt64>(receive_timeout.totalMicroseconds())));
bool read_ok = false;
/// We are waiting for a packet from the client. Thus, every `POLL_INTERVAL` seconds check whether we need to shut down.
while (true)
{
if (static_cast<ReadBufferFromPocoSocket &>(*in).poll(poll_interval))
if (static_cast<ReadBufferFromPocoSocket &>(*in).poll(timeout_ms))
{
/// If client disconnected.
if (in->eof())
{
LOG_INFO(log, "Client has dropped the connection, cancel the query.");
state.is_connection_closed = true;
break;
}
/// We accept and process data.
read_ok = receivePacket();
break;
}
/// Do we need to shut down?
if (server.isCancelled())
return false;
break;
/** Have we waited for data for too long?
* If we periodically poll, the receive_timeout of the socket itself does not work.
* Therefore, an additional check is added.
*/
Float64 elapsed = watch.elapsedSeconds();
if (elapsed > static_cast<Float64>(receive_timeout))
if (elapsed > static_cast<Float64>(receive_timeout.totalSeconds()))
{
throw Exception(ErrorCodes::SOCKET_TIMEOUT,
"Timeout exceeded while receiving data from client. Waited for {} seconds, timeout is {} seconds.",
static_cast<size_t>(elapsed), receive_timeout);
static_cast<size_t>(elapsed), receive_timeout.totalSeconds());
}
}
/// If client disconnected.
if (in->eof())
{
LOG_INFO(log, "Client has dropped the connection, cancel the query.");
state.is_connection_closed = true;
return false;
}
/// A successfully processed packet is followed by flushing pending logs;
/// otherwise the state is marked as having read all data.
if (read_ok)
sendLogs();
else
state.read_all_data = true;
/// We accept and process data. And if they are over, then we leave.
if (!receivePacket())
return false;
sendLogs();
return true;
return read_ok;
}
std::tuple<size_t, int> TCPHandler::getReadTimeouts(const Settings & connection_settings)
void TCPHandler::readData()
{
const auto receive_timeout = query_context->getSettingsRef().receive_timeout.value;
/// Poll interval should not be greater than receive_timeout
const size_t default_poll_interval = connection_settings.poll_interval * 1000000;
size_t current_poll_interval = static_cast<size_t>(receive_timeout.totalMicroseconds());
constexpr size_t min_poll_interval = 5000; // 5 ms
size_t poll_interval = std::max(min_poll_interval, std::min(default_poll_interval, current_poll_interval));
return std::make_tuple(poll_interval, receive_timeout.totalSeconds());
}
void TCPHandler::readData(const Settings & connection_settings)
{
auto [poll_interval, receive_timeout] = getReadTimeouts(connection_settings);
sendLogs();
while (readDataNext(poll_interval, receive_timeout))
while (readDataNext())
;
}
void TCPHandler::processInsertQuery(const Settings & connection_settings)
void TCPHandler::skipData()
{
state.skipping_data = true;
SCOPE_EXIT({ state.skipping_data = false; });
while (readDataNext())
;
}
void TCPHandler::processInsertQuery()
{
/** Made above the rest of the lines, so that in case of `writePrefix` function throws an exception,
* client receive exception before sending data.
@ -592,7 +587,7 @@ void TCPHandler::processInsertQuery(const Settings & connection_settings)
try
{
readData(connection_settings);
readData();
}
catch (...)
{
@ -631,7 +626,7 @@ void TCPHandler::processOrdinaryQuery()
break;
}
if (after_send_progress.elapsed() / 1000 >= query_context->getSettingsRef().interactive_delay)
if (after_send_progress.elapsed() / 1000 >= interactive_delay)
{
/// Some time passed.
after_send_progress.restart();
@ -640,7 +635,7 @@ void TCPHandler::processOrdinaryQuery()
sendLogs();
if (async_in.poll(query_context->getSettingsRef().interactive_delay / 1000))
if (async_in.poll(interactive_delay / 1000))
{
const auto block = async_in.read();
if (!block)
@ -695,7 +690,7 @@ void TCPHandler::processOrdinaryQueryWithProcessors()
CurrentMetrics::Increment query_thread_metric_increment{CurrentMetrics::QueryThread};
Block block;
while (executor.pull(block, query_context->getSettingsRef().interactive_delay / 1000))
while (executor.pull(block, interactive_delay / 1000))
{
std::lock_guard lock(task_callback_mutex);
@ -706,7 +701,7 @@ void TCPHandler::processOrdinaryQueryWithProcessors()
break;
}
if (after_send_progress.elapsed() / 1000 >= query_context->getSettingsRef().interactive_delay)
if (after_send_progress.elapsed() / 1000 >= interactive_delay)
{
/// Some time passed and there is a progress.
after_send_progress.restart();
@ -753,11 +748,13 @@ void TCPHandler::processTablesStatusRequest()
TablesStatusRequest request;
request.read(*in, client_tcp_protocol_version);
ContextPtr context_to_resolve_table_names = session->sessionContext() ? session->sessionContext() : server.context();
TablesStatusResponse response;
for (const QualifiedTableName & table_name: request.tables)
{
auto resolved_id = connection_context->tryResolveStorageID({table_name.database, table_name.table});
StoragePtr table = DatabaseCatalog::instance().tryGetTable(resolved_id, connection_context);
auto resolved_id = context_to_resolve_table_names->tryResolveStorageID({table_name.database, table_name.table});
StoragePtr table = DatabaseCatalog::instance().tryGetTable(resolved_id, context_to_resolve_table_names);
if (!table)
continue;
@ -777,11 +774,10 @@ void TCPHandler::processTablesStatusRequest()
writeVarUInt(Protocol::Server::TablesStatusResponse, *out);
/// For testing hedged requests
const Settings & settings = query_context->getSettingsRef();
if (settings.sleep_in_send_tables_status_ms.totalMilliseconds())
if (sleep_in_send_tables_status.totalMilliseconds())
{
out->next();
std::chrono::milliseconds ms(settings.sleep_in_send_tables_status_ms.totalMilliseconds());
std::chrono::milliseconds ms(sleep_in_send_tables_status.totalMilliseconds());
std::this_thread::sleep_for(ms);
}
@ -924,7 +920,7 @@ bool TCPHandler::receiveProxyHeader()
}
LOG_TRACE(log, "Forwarded client address from PROXY header: {}", forwarded_address);
connection_context->getClientInfo().forwarded_for = forwarded_address;
session->getClientInfo().forwarded_for = forwarded_address;
return true;
}
@ -973,14 +969,21 @@ void TCPHandler::receiveHello()
(!user.empty() ? ", user: " + user : "")
);
if (user != USER_INTERSERVER_MARKER)
{
connection_context->setUser(user, password, socket().peerAddress());
}
else
auto & client_info = session->getClientInfo();
client_info.client_name = client_name;
client_info.client_version_major = client_version_major;
client_info.client_version_minor = client_version_minor;
client_info.client_version_patch = client_version_patch;
client_info.client_tcp_protocol_version = client_tcp_protocol_version;
is_interserver_mode = (user == USER_INTERSERVER_MARKER);
if (is_interserver_mode)
{
receiveClusterNameAndSalt();
return;
}
session->authenticate(user, password, socket().peerAddress());
}
@ -1027,8 +1030,11 @@ bool TCPHandler::receivePacket()
{
case Protocol::Client::IgnoredPartUUIDs:
/// Part uuids packet if any comes before query.
if (!state.empty() || state.part_uuids_to_ignore)
receiveUnexpectedIgnoredPartUUIDs();
receiveIgnoredPartUUIDs();
return true;
case Protocol::Client::Query:
if (!state.empty())
receiveUnexpectedQuery();
@ -1037,8 +1043,10 @@ bool TCPHandler::receivePacket()
case Protocol::Client::Data:
case Protocol::Client::Scalar:
if (state.skipping_data)
return receiveUnexpectedData(false);
if (state.empty())
receiveUnexpectedData();
receiveUnexpectedData(true);
return receiveData(packet_type == Protocol::Client::Scalar);
case Protocol::Client::Ping:
@ -1049,10 +1057,9 @@ bool TCPHandler::receivePacket()
case Protocol::Client::Cancel:
{
/// For testing connection collector.
const Settings & settings = query_context->getSettingsRef();
if (settings.sleep_in_receive_cancel_ms.totalMilliseconds())
if (sleep_in_receive_cancel.totalMilliseconds())
{
std::chrono::milliseconds ms(settings.sleep_in_receive_cancel_ms.totalMilliseconds());
std::chrono::milliseconds ms(sleep_in_receive_cancel.totalMilliseconds());
std::this_thread::sleep_for(ms);
}
@ -1074,14 +1081,18 @@ bool TCPHandler::receivePacket()
}
}
/// Reads the vector of part UUIDs sent by the client in an IgnoredPartUUIDs packet.
/// NOTE(review): this span appears to come from a flattened diff — the `state.part_uuids` flag with
/// the local `uuids` vector, and the read into `state.part_uuids_to_ignore`, look like the
/// before/after variants of the same step. Confirm against the real file.
void TCPHandler::receiveIgnoredPartUUIDs()
{
state.part_uuids = true;
std::vector<UUID> uuids;
readVectorBinary(uuids, *in);
/// emplace() engages the optional with an empty vector, which is then filled directly from the input.
readVectorBinary(state.part_uuids_to_ignore.emplace(), *in);
}
if (!uuids.empty())
query_context->getIgnoredPartUUIDs()->add(uuids);
/// Handles an IgnoredPartUUIDs packet that arrived when it was not expected:
/// the payload is read and dropped, then the situation is reported as a protocol error.
void TCPHandler::receiveUnexpectedIgnoredPartUUIDs()
{
    /// Read the payload before throwing (presumably so the packet is fully consumed from the input — TODO confirm).
    std::vector<UUID> discarded_uuids;
    readVectorBinary(discarded_uuids, *in);

    throw NetException("Unexpected packet IgnoredPartUUIDs received from client", ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT);
}
@ -1095,10 +1106,9 @@ String TCPHandler::receiveReadTaskResponseAssumeLocked()
{
state.is_cancelled = true;
/// For testing connection collector.
const Settings & settings = query_context->getSettingsRef();
if (settings.sleep_in_receive_cancel_ms.totalMilliseconds())
if (sleep_in_receive_cancel.totalMilliseconds())
{
std::chrono::milliseconds ms(settings.sleep_in_receive_cancel_ms.totalMilliseconds());
std::chrono::milliseconds ms(sleep_in_receive_cancel.totalMilliseconds());
std::this_thread::sleep_for(ms);
}
return {};
@ -1129,14 +1139,14 @@ void TCPHandler::receiveClusterNameAndSalt()
if (salt.empty())
throw NetException("Empty salt is not allowed", ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT);
cluster_secret = query_context->getCluster(cluster)->getSecret();
cluster_secret = server.context()->getCluster(cluster)->getSecret();
}
catch (const Exception & e)
{
try
{
/// We try to send error information to the client.
sendException(e, connection_context->getSettingsRef().calculate_text_stack_trace);
sendException(e, send_exception_with_stack_trace);
}
catch (...) {}
@ -1152,25 +1162,11 @@ void TCPHandler::receiveQuery()
state.is_empty = false;
readStringBinary(state.query_id, *in);
/// Client info
ClientInfo & client_info = query_context->getClientInfo();
/// Read client info.
ClientInfo client_info = session->getClientInfo();
if (client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_CLIENT_INFO)
client_info.read(*in, client_tcp_protocol_version);
/// For better support of old clients, that does not send ClientInfo.
if (client_info.query_kind == ClientInfo::QueryKind::NO_QUERY)
{
client_info.query_kind = ClientInfo::QueryKind::INITIAL_QUERY;
client_info.client_name = client_name;
client_info.client_version_major = client_version_major;
client_info.client_version_minor = client_version_minor;
client_info.client_version_patch = client_version_patch;
client_info.client_tcp_protocol_version = client_tcp_protocol_version;
}
/// Set fields, that are known apriori.
client_info.interface = ClientInfo::Interface::TCP;
/// Per query settings are also passed via TCP.
/// We need to check them before applying due to they can violate the settings constraints.
auto settings_format = (client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_SETTINGS_SERIALIZED_AS_STRINGS)
@ -1191,12 +1187,11 @@ void TCPHandler::receiveQuery()
readVarUInt(compression, *in);
state.compression = static_cast<Protocol::Compression>(compression);
last_block_in.compression = state.compression;
readStringBinary(state.query, *in);
/// It is OK to check only when query != INITIAL_QUERY,
/// since only in that case the actions will be done.
if (!cluster.empty() && client_info.query_kind != ClientInfo::QueryKind::INITIAL_QUERY)
if (is_interserver_mode)
{
#if USE_SSL
std::string data(salt);
@ -1218,26 +1213,33 @@ void TCPHandler::receiveQuery()
/// i.e. when the INSERT is done with the global context (w/o user).
if (!client_info.initial_user.empty())
{
query_context->setUserWithoutCheckingPassword(client_info.initial_user, client_info.initial_address);
LOG_DEBUG(log, "User (initial): {}", query_context->getUserName());
LOG_DEBUG(log, "User (initial): {}", client_info.initial_user);
session->authenticate(AlwaysAllowCredentials{client_info.initial_user}, client_info.initial_address);
}
/// No need to update connection_context, since it does not requires user (it will not be used for query execution)
#else
throw Exception(
"Inter-server secret support is disabled, because ClickHouse was built without SSL library",
ErrorCodes::SUPPORT_IS_DISABLED);
#endif
}
else
{
query_context->setInitialRowPolicy();
}
query_context = session->makeQueryContext(std::move(client_info));
/// Sets the default database if it wasn't set earlier for the session context.
if (!default_database.empty() && !session->sessionContext())
query_context->setCurrentDatabase(default_database);
if (state.part_uuids_to_ignore)
query_context->getIgnoredPartUUIDs()->add(*state.part_uuids_to_ignore);
query_context->setProgressCallback([this] (const Progress & value) { return this->updateProgress(value); });
///
/// Settings
///
auto settings_changes = passed_settings.changes();
if (client_info.query_kind == ClientInfo::QueryKind::INITIAL_QUERY)
auto query_kind = query_context->getClientInfo().query_kind;
if (query_kind == ClientInfo::QueryKind::INITIAL_QUERY)
{
/// Throw an exception if the passed settings violate the constraints.
query_context->checkSettingsConstraints(settings_changes);
@ -1249,40 +1251,24 @@ void TCPHandler::receiveQuery()
}
query_context->applySettingsChanges(settings_changes);
/// Use the received query id, or generate a random default. It is convenient
/// to also generate the default OpenTelemetry trace id at the same time, and
/// set the trace parent.
/// Notes:
/// 1) ClientInfo might contain upstream trace id, so we decide whether to use
/// the default ids after we have received the ClientInfo.
/// 2) There is the opentelemetry_start_trace_probability setting that
/// controls when we start a new trace. It can be changed via Native protocol,
/// so we have to apply the changes first.
query_context->setCurrentQueryId(state.query_id);
/// Disable function name normalization when it's a secondary query, because queries are either
/// already normalized on initiator node, or not normalized and should remain unnormalized for
/// compatibility.
if (client_info.query_kind == ClientInfo::QueryKind::SECONDARY_QUERY)
if (query_kind == ClientInfo::QueryKind::SECONDARY_QUERY)
{
query_context->setSetting("normalize_function_names", Field(0));
}
// Use the received query id, or generate a random default. It is convenient
// to also generate the default OpenTelemetry trace id at the same time, and
// set the trace parent.
// Why is this done here and not earlier:
// 1) ClientInfo might contain upstream trace id, so we decide whether to use
// the default ids after we have received the ClientInfo.
// 2) There is the opentelemetry_start_trace_probability setting that
// controls when we start a new trace. It can be changed via Native protocol,
// so we have to apply the changes first.
query_context->setCurrentQueryId(state.query_id);
// Set parameters of initial query.
if (client_info.query_kind == ClientInfo::QueryKind::INITIAL_QUERY)
{
/// 'Current' fields was set at receiveHello.
client_info.initial_user = client_info.current_user;
client_info.initial_query_id = client_info.current_query_id;
client_info.initial_address = client_info.current_address;
}
/// Sync timeouts on client and server during current query to avoid dangling queries on server
/// NOTE: We use settings.send_timeout for the receive timeout and vice versa (change arguments ordering in TimeoutSetter),
/// because settings.send_timeout is client-side setting which has opposite meaning on the server side.
/// NOTE: these settings are applied only for current connection (not for distributed tables' connections)
const Settings & settings = query_context->getSettingsRef();
state.timeout_setter = std::make_unique<TimeoutSetter>(socket(), settings.receive_timeout, settings.send_timeout);
}
void TCPHandler::receiveUnexpectedQuery()
@ -1307,7 +1293,10 @@ void TCPHandler::receiveUnexpectedQuery()
readStringBinary(skip_hash, *in, 32);
readVarUInt(skip_uint_64, *in);
readVarUInt(skip_uint_64, *in);
last_block_in.compression = static_cast<Protocol::Compression>(skip_uint_64);
readStringBinary(skip_string, *in);
throw NetException("Unexpected packet Query received from client", ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT);
@ -1324,73 +1313,77 @@ bool TCPHandler::receiveData(bool scalar)
/// Read one block from the network and write it down
Block block = state.block_in->read();
if (block)
if (!block)
{
if (scalar)
{
/// Scalar value
query_context->addScalar(temporary_id.table_name, block);
}
else if (!state.need_receive_data_for_insert && !state.need_receive_data_for_input)
{
/// Data for external tables
state.read_all_data = true;
return false;
}
auto resolved = query_context->tryResolveStorageID(temporary_id, Context::ResolveExternal);
StoragePtr storage;
/// If such a table does not exist, create it.
if (resolved)
{
storage = DatabaseCatalog::instance().getTable(resolved, query_context);
}
else
{
NamesAndTypesList columns = block.getNamesAndTypesList();
auto temporary_table = TemporaryTableHolder(query_context, ColumnsDescription{columns}, {});
storage = temporary_table.getTable();
query_context->addExternalTable(temporary_id.table_name, std::move(temporary_table));
}
auto metadata_snapshot = storage->getInMemoryMetadataPtr();
/// The data will be written directly to the table.
auto temporary_table_out = std::make_shared<PushingToSinkBlockOutputStream>(storage->write(ASTPtr(), metadata_snapshot, query_context));
temporary_table_out->write(block);
temporary_table_out->writeSuffix();
if (scalar)
{
/// Scalar value
query_context->addScalar(temporary_id.table_name, block);
}
else if (!state.need_receive_data_for_insert && !state.need_receive_data_for_input)
{
/// Data for external tables
}
else if (state.need_receive_data_for_input)
auto resolved = query_context->tryResolveStorageID(temporary_id, Context::ResolveExternal);
StoragePtr storage;
/// If such a table does not exist, create it.
if (resolved)
{
/// 'input' table function.
state.block_for_input = block;
storage = DatabaseCatalog::instance().getTable(resolved, query_context);
}
else
{
/// INSERT query.
state.io.out->write(block);
NamesAndTypesList columns = block.getNamesAndTypesList();
auto temporary_table = TemporaryTableHolder(query_context, ColumnsDescription{columns}, {});
storage = temporary_table.getTable();
query_context->addExternalTable(temporary_id.table_name, std::move(temporary_table));
}
return true;
auto metadata_snapshot = storage->getInMemoryMetadataPtr();
/// The data will be written directly to the table.
auto temporary_table_out = std::make_shared<PushingToSinkBlockOutputStream>(storage->write(ASTPtr(), metadata_snapshot, query_context));
temporary_table_out->write(block);
temporary_table_out->writeSuffix();
}
else if (state.need_receive_data_for_input)
{
/// 'input' table function.
state.block_for_input = block;
}
else
return false;
{
/// INSERT query.
state.io.out->write(block);
}
return true;
}
void TCPHandler::receiveUnexpectedData()
/// Reads one Data packet that is not going to be used: the external table name and one block.
/// If `throw_exception` is true, throws UNEXPECTED_PACKET_FROM_CLIENT after the packet has been read;
/// otherwise returns whether a non-empty block was read.
/// NOTE(review): this span appears to come from a flattened diff — the two constructions of
/// `skip_block_in`, the two read() calls and the unconditional throw look like the before/after
/// variants of the same steps. Confirm against the real file.
bool TCPHandler::receiveUnexpectedData(bool throw_exception)
{
String skip_external_table_name;
readStringBinary(skip_external_table_name, *in);
/// The block may be compressed, depending on the compression mode recorded in last_block_in.
std::shared_ptr<ReadBuffer> maybe_compressed_in;
if (last_block_in.compression == Protocol::Compression::Enable)
maybe_compressed_in = std::make_shared<CompressedReadBuffer>(*in, /* allow_different_codecs */ true);
else
maybe_compressed_in = in;
auto skip_block_in = std::make_shared<NativeBlockInputStream>(
*maybe_compressed_in,
last_block_in.header,
client_tcp_protocol_version);
auto skip_block_in = std::make_shared<NativeBlockInputStream>(*maybe_compressed_in, client_tcp_protocol_version);
/// The block returned by read() is converted to bool; read_ok is false when that conversion yields false.
bool read_ok = skip_block_in->read();
skip_block_in->read();
throw NetException("Unexpected packet Data received from client", ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT);
/// A failed read marks the data of the query as completely read.
if (!read_ok)
state.read_all_data = true;
if (throw_exception)
throw NetException("Unexpected packet Data received from client", ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT);
return read_ok;
}
void TCPHandler::initBlockInput()
@ -1411,9 +1404,6 @@ void TCPHandler::initBlockInput()
else if (state.need_receive_data_for_input)
header = state.input_header;
last_block_in.header = header;
last_block_in.compression = state.compression;
state.block_in = std::make_shared<NativeBlockInputStream>(
*state.maybe_compressed_in,
header,
@ -1426,10 +1416,9 @@ void TCPHandler::initBlockOutput(const Block & block)
{
if (!state.block_out)
{
const Settings & query_settings = query_context->getSettingsRef();
if (!state.maybe_compressed_out)
{
const Settings & query_settings = query_context->getSettingsRef();
std::string method = Poco::toUpper(query_settings.network_compression_method.toString());
std::optional<int> level;
if (method == "ZSTD")
@ -1450,7 +1439,7 @@ void TCPHandler::initBlockOutput(const Block & block)
*state.maybe_compressed_out,
client_tcp_protocol_version,
block.cloneEmpty(),
!connection_context->getSettingsRef().low_cardinality_allow_in_native_format);
!query_settings.low_cardinality_allow_in_native_format);
}
}
@ -1459,11 +1448,12 @@ void TCPHandler::initLogsBlockOutput(const Block & block)
if (!state.logs_block_out)
{
/// Use uncompressed stream since log blocks usually contain only one row
const Settings & query_settings = query_context->getSettingsRef();
state.logs_block_out = std::make_shared<NativeBlockOutputStream>(
*out,
client_tcp_protocol_version,
block.cloneEmpty(),
!connection_context->getSettingsRef().low_cardinality_allow_in_native_format);
!query_settings.low_cardinality_allow_in_native_format);
}
}
@ -1473,7 +1463,7 @@ bool TCPHandler::isQueryCancelled()
if (state.is_cancelled || state.sent_all_data)
return true;
if (after_check_cancelled.elapsed() / 1000 < query_context->getSettingsRef().interactive_delay)
if (after_check_cancelled.elapsed() / 1000 < interactive_delay)
return false;
after_check_cancelled.restart();
@ -1501,10 +1491,9 @@ bool TCPHandler::isQueryCancelled()
state.is_cancelled = true;
/// For testing connection collector.
{
const Settings & settings = query_context->getSettingsRef();
if (settings.sleep_in_receive_cancel_ms.totalMilliseconds())
if (sleep_in_receive_cancel.totalMilliseconds())
{
std::chrono::milliseconds ms(settings.sleep_in_receive_cancel_ms.totalMilliseconds());
std::chrono::milliseconds ms(sleep_in_receive_cancel.totalMilliseconds());
std::this_thread::sleep_for(ms);
}
}
@ -1542,11 +1531,10 @@ void TCPHandler::sendData(const Block & block)
writeStringBinary("", *out);
/// For testing hedged requests
const Settings & settings = query_context->getSettingsRef();
if (block.rows() > 0 && settings.sleep_in_send_data_ms.totalMilliseconds())
if (block.rows() > 0 && query_context->getSettingsRef().sleep_in_send_data_ms.totalMilliseconds())
{
out->next();
std::chrono::milliseconds ms(settings.sleep_in_send_data_ms.totalMilliseconds());
std::chrono::milliseconds ms(query_context->getSettingsRef().sleep_in_send_data_ms.totalMilliseconds());
std::this_thread::sleep_for(ms);
}

View File

@ -11,7 +11,7 @@
#include <IO/TimeoutSetter.h>
#include <DataStreams/BlockIO.h>
#include <Interpreters/InternalTextLogsQueue.h>
#include <Interpreters/Context.h>
#include <Interpreters/Context_fwd.h>
#include "IServer.h"
@ -26,7 +26,10 @@ namespace Poco { class Logger; }
namespace DB
{
class Session;
struct Settings;
class ColumnsDescription;
struct BlockStreamProfileInfo;
/// State of query processing.
struct QueryState
@ -65,11 +68,11 @@ struct QueryState
bool sent_all_data = false;
/// Request requires data from the client (INSERT, but not INSERT SELECT).
bool need_receive_data_for_insert = false;
/// Temporary tables read
bool temporary_tables_read = false;
/// Data was read.
bool read_all_data = false;
/// UUIDs of parts to exclude from the query, once they have been received.
bool part_uuids = false;
std::optional<std::vector<UUID>> part_uuids_to_ignore;
/// Request requires data from client for function input()
bool need_receive_data_for_input = false;
@ -78,6 +81,9 @@ struct QueryState
/// sample block from StorageInput
Block input_header;
/// If true, data packets are skipped instead of being read. Used to recover after errors.
bool skipping_data = false;
/// Progress accumulated since it was last sent; used to output progress as a difference.
Progress progress;
@ -99,7 +105,6 @@ struct QueryState
/// Parameters remembered for the block input stream: a compression mode and a header block.
/// NOTE(review): presumably the type of `last_block_in`, which TCPHandler::initBlockInput()
/// fills from `state.compression` and the chosen header - confirm against the full source.
struct LastBlockInputParameters
{
/// Compression mode; defaults to disabled.
Protocol::Compression compression = Protocol::Compression::Disable;
/// Header (sample block) associated with the input stream.
Block header;
};
class TCPHandler : public Poco::Net::TCPServerConnection
@ -132,10 +137,19 @@ private:
UInt64 client_version_patch = 0;
UInt64 client_tcp_protocol_version = 0;
ContextMutablePtr connection_context;
ContextMutablePtr query_context;
/// Connection settings, which are extracted from a context.
bool send_exception_with_stack_trace = true;
Poco::Timespan send_timeout = DBMS_DEFAULT_SEND_TIMEOUT_SEC;
Poco::Timespan receive_timeout = DBMS_DEFAULT_RECEIVE_TIMEOUT_SEC;
UInt64 poll_interval = DBMS_DEFAULT_POLL_INTERVAL;
UInt64 idle_connection_timeout = 3600;
UInt64 interactive_delay = 100000;
Poco::Timespan sleep_in_send_tables_status;
UInt64 unknown_packet_in_send_data = 0;
Poco::Timespan sleep_in_receive_cancel;
size_t unknown_packet_in_send_data = 0;
std::unique_ptr<Session> session;
ContextMutablePtr query_context;
/// Streams for reading/writing from/to client connection socket.
std::shared_ptr<ReadBuffer> in;
@ -148,6 +162,7 @@ private:
String default_database;
/// For inter-server secret (remote_server.*.secret)
bool is_interserver_mode = false;
String salt;
String cluster;
String cluster_secret;
@ -167,6 +182,8 @@ private:
void runImpl();
void extractConnectionSettingsFromContext(const ContextPtr & context);
bool receiveProxyHeader();
void receiveHello();
bool receivePacket();
@ -174,18 +191,19 @@ private:
void receiveIgnoredPartUUIDs();
String receiveReadTaskResponseAssumeLocked();
bool receiveData(bool scalar);
bool readDataNext(size_t poll_interval, time_t receive_timeout);
void readData(const Settings & connection_settings);
bool readDataNext();
void readData();
void skipData();
void receiveClusterNameAndSalt();
std::tuple<size_t, int> getReadTimeouts(const Settings & connection_settings);
[[noreturn]] void receiveUnexpectedData();
bool receiveUnexpectedData(bool throw_exception = true);
[[noreturn]] void receiveUnexpectedQuery();
[[noreturn]] void receiveUnexpectedIgnoredPartUUIDs();
[[noreturn]] void receiveUnexpectedHello();
[[noreturn]] void receiveUnexpectedTablesStatusRequest();
/// Process INSERT query
void processInsertQuery(const Settings & connection_settings);
void processInsertQuery();
/// Process a request that does not require the receiving of data blocks from the client
void processOrdinaryQuery();

View File

@ -24,3 +24,4 @@ def test_different_versions(start_cluster):
node.query("SELECT 1", settings={'max_concurrent_queries_for_user': 1})
assert node.contains_in_log('Too many simultaneous queries for user')
assert not node.contains_in_log('Unknown packet')
assert not node.contains_in_log('Unexpected packet')

View File

@ -1,8 +1,20 @@
===http===
{"query":"select 1 from remote('127.0.0.2', system, one) format Null\n","status":"QueryFinish","tracestate":"some custom state","sorted_by_start_time":1}
{"query":"DESC TABLE system.one","status":"QueryFinish","tracestate":"some custom state","sorted_by_start_time":1}
{"query":"DESC TABLE system.one","status":"QueryFinish","tracestate":"some custom state","sorted_by_start_time":1}
{"query":"SELECT 1 FROM system.one","status":"QueryFinish","tracestate":"some custom state","sorted_by_start_time":1}
{"query":"DESC TABLE system.one","query_status":"QueryFinish","tracestate":"some custom state","sorted_by_finish_time":1}
{"query":"DESC TABLE system.one","query_status":"QueryFinish","tracestate":"some custom state","sorted_by_finish_time":1}
{"query":"SELECT 1 FROM system.one","query_status":"QueryFinish","tracestate":"some custom state","sorted_by_finish_time":1}
{"query":"select 1 from remote('127.0.0.2', system, one) format Null\n","query_status":"QueryFinish","tracestate":"some custom state","sorted_by_finish_time":1}
{"total spans":"4","unique spans":"4","unique non-zero parent spans":"3"}
{"initial query spans with proper parent":"1"}
{"unique non-empty tracestate values":"1"}
===native===
{"query":"select * from url('http:\/\/127.0.0.2:8123\/?query=select%201%20format%20Null', CSV, 'a int')","status":"QueryFinish","tracestate":"another custom state","sorted_by_start_time":1}
{"query":"select 1 format Null\n","status":"QueryFinish","tracestate":"another custom state","sorted_by_start_time":1}
{"query":"select 1 format Null\n","query_status":"QueryFinish","tracestate":"another custom state","sorted_by_finish_time":1}
{"query":"select * from url('http:\/\/127.0.0.2:8123\/?query=select%201%20format%20Null', CSV, 'a int')","query_status":"QueryFinish","tracestate":"another custom state","sorted_by_finish_time":1}
{"total spans":"2","unique spans":"2","unique non-zero parent spans":"2"}
{"initial query spans with proper parent":"1"}
{"unique non-empty tracestate values":"1"}

View File

@ -12,6 +12,28 @@ function check_log
${CLICKHOUSE_CLIENT} --format=JSONEachRow -nq "
system flush logs;
-- Show queries sorted by start time.
select attribute['db.statement'] as query,
attribute['clickhouse.query_status'] as status,
attribute['clickhouse.tracestate'] as tracestate,
1 as sorted_by_start_time
from system.opentelemetry_span_log
where trace_id = reinterpretAsUUID(reverse(unhex('$trace_id')))
and operation_name = 'query'
order by start_time_us
;
-- Show queries sorted by finish time.
select attribute['db.statement'] as query,
attribute['clickhouse.query_status'] as query_status,
attribute['clickhouse.tracestate'] as tracestate,
1 as sorted_by_finish_time
from system.opentelemetry_span_log
where trace_id = reinterpretAsUUID(reverse(unhex('$trace_id')))
and operation_name = 'query'
order by finish_time_us
;
-- Check the number of query spans with given trace id, to verify it was
-- propagated.
select count(*) "'"'"total spans"'"'",
@ -89,10 +111,10 @@ check_log
echo "===sampled==="
query_id=$(${CLICKHOUSE_CLIENT} -q "select lower(hex(reverse(reinterpretAsString(generateUUIDv4()))))")
for i in {1..200}
for i in {1..20}
do
${CLICKHOUSE_CLIENT} \
--opentelemetry_start_trace_probability=0.1 \
--opentelemetry_start_trace_probability=0.5 \
--query_id "$query_id-$i" \
--query "select 1 from remote('127.0.0.2', system, one) format Null" \
&
@ -108,8 +130,8 @@ wait
${CLICKHOUSE_CLIENT} -q "system flush logs"
${CLICKHOUSE_CLIENT} -q "
-- expect 200 * 0.1 = 20 sampled events on average
select if(count() > 1 and count() < 50, 'OK', 'Fail')
-- expect 20 * 0.5 = 10 sampled events on average
select if(2 <= count() and count() <= 18, 'OK', 'Fail')
from system.opentelemetry_span_log
where operation_name = 'query'
and parent_span_id = 0 -- only account for the initial queries